换根DP

  |   0 评论   |   0 浏览

转自:大佬的博客

换根dp是一种用来求解树上各点到其他点的距离之和的算法。

在进行换根dp时,需要先利用 dfs 计算出根节点到其他节点的距离之和,以及以每个节点为根节点的子树的节点数量。然后再对其进行换根dp。以下假设每条边的长度为1。

Part1:计算根节点到其他节点距离之和
我们在求解时利用递归的思想进行求解:

假设 a 为根节点,b 为直系子节点,那么对于 b 所在子树对 a 的贡献为 dis[b]+point[b] ,其中 dis[b]为 b到以b为根节点的子树中所有点的距离之和。point[b]为以 b 为根节点的子树中的点的个数,其实很好理解,就相当于以 b 为根节点中的所有路径长度全部 +1,然后就到达了 a 节点。

代码如下:

int dis[maxn];//第一次dfs每个节点到其子节点距离之和
int point[maxn];//每个节点下子节点个数(包括这个节点本身)
int dp[maxn];//最终结果
bool vis[maxn];
vectorvec[maxn];
int dis[maxn];//第一次dfs每个节点到其子节点距离之和 
int point[maxn];//每个节点下子节点个数(包括这个节点本身) 
int dp[maxn];//最终结果  
bool vis[maxn];
vector<int>vec[maxn];

void dfs(int x){
	vis[x]=true;
	int sum=0;
	for(int i=0;i<vec[x].size();i++){
		int y=vec[x][i];
		if(!vis[y]){
			dfs(y);
			sum+=point[y];
			dis[x]+=dis[y]+point[y];
		}
	}
	point[x]=sum+1;
	return ;
}

Part2:进行换根dp
准备工作做完之后,我们就可以开始换根dp,换根dp的思想就是把与根相连的节点通过一定的操作将其变为根。

依然利用上述节点 a,b ,将根节点从 a 移动到 b ,dp[b] 的值为:dp[a]−point[b]+(n−point[b]),其中:

−point[b] 表示从 b 引申出来的 point[b] 条路径长度全部 −1.

n−point[b] 表示从 a 引申出来的不包含 b 的其他路径长度全部 +1
状态转移方程计算出来之后就可以利用 dfs 进行换根dp了。

代码如下(初始状态下dp[a]=dis[a]):

memset(vis,0,sizeof(vis));//dfs时使用过一次,因此需要清空
dp[a]=dis[a];//此处a指代上面自己找的dfs起点,一般情况下a=1
void Dp(int x){
	vis[x]=true;
	for(int i=0;i<vec[x].size();i++){
		int y=vec[x][i];
		if(!vis[y]){
			dp[y]=dp[x]-point[y]+n-point[y];
			Dp(y);
		}
	}
}dp[a]=dis[a];//此处a指代上面自己找的dfs起点,一般情况下a=1
void Dp(int x){
vis[x]=true;
for(int i=0;i<vec[x].size();i++){
int y=vec[x][i];
if(!vis[y]){
dp[y]=dp[x]-point[y]+n-point[y];
Dp(y);
}
}
}

这样就能计算出所有的点到其他节点的距离之和了。

Tree

这是一个换根dp模板题,直接套用就好了。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e6+10;
int dis[maxn],n;//第一次dfs每个节点到其子节点距离之和
int point[maxn];//每个节点下子节点个数(包括这个节点本身)
int dp[maxn];//最终结果
bool vis[maxn];
vector<int>vec[maxn];

void dfs(int x){
    vis[x]=true;
    int sum=0;
    for(int i=0;i<vec[x].size();i++){
        int y=vec[x][i];
        if(!vis[y]){
            dfs(y);
            sum+=point[y];
            dis[x]+=dis[y]+point[y];
        }
    }
    point[x]=sum+1;
    return ;
}
 
void Dp(int x){
    vis[x]=true;
    for(int i=0;i<vec[x].size();i++){
        int y=vec[x][i];
        if(!vis[y]){
            dp[y]=dp[x]-point[y]+n-point[y];
            Dp(y);
        }
    }
}

int main()
{
	scanf("%d",&n);
    for(int i=1,u,v;i<n;i++){
        scanf("%d %d",&u,&v);
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    dfs(1);
    memset(vis,0,sizeof(vis));
    dp[1]=dis[1];
    Dp(1);
    int ans=dp[1];
    for(int i=2;i<=n;i++){
	    ans=min(ans,dp[i]);
	}
    printf("%d\n",ans);
}