【2020牛客多校第七场 E】NeoMole Synthesis 题解

题目大意

  给定一棵 n n n 个点的目标树,以及 m m m 棵模板树,每棵模板树有一个单价 c i c_i ci,数量无限多。这里的树都是无根树。
  现在要用若干模板树拼成目标树(就是用模板去覆盖目标树,使得目标树的每个点恰好被覆盖一次),求最小代价。

   n ≤ 500 ,   m ≤ 200 n \leq 500,\ m \leq 200 n500, m200,所有模板树的结点数总和 N ≤ 500 N \le 500 N500
   c i ≤ 1 0 6 c_i \leq 10^6 ci106
  1s

\\
\\
\\

题解

  妙啊。。。

  首先大的框架是个树形 dp。把目标树当有根树,设 d p i , ( j , p j ) dp_{i,(j,pj)} dpi,(j,pj) 表示目标树的第 i i i 个结点,匹配模板里的结点 j j j i i i 连向父亲的边匹配 j j j 连向 p j pj pj 的边,的最小代价;设 g i g_i gi 表示目标树以 i i i 为根的子树完全匹配的最小代价。
   d p dp dp 数组的状态数是 O ( n N ) O(nN) O(nN) 的。 g i g_i gi 也可以视为 min ⁡ j d p i , ( j , 0 ) \min_{j} dp_{i,(j,0)} minjdpi,(j,0) 0 0 0 就表示 j j j 没有父亲,是所在其模板树的根),所以以下就是求 d p dp dp 数组。

  而这个转移就是让 i i i 的儿子去匹配 j j j 的儿子。这是一个二分图最小权匹配:左边一排 d e g j − 1 deg_j-1 degj1 个点表示模板树上 j j j 的儿子(如果是在做 d p i , ( j , 0 ) dp_{i,(j,0)} dpi,(j,0),那么 j j j 没有父亲,左边就有 d e g j deg_j degj 个点),右边一排 d e g i − 1 deg_i-1 degi1 个点表示目标树上 i i i 的儿子,左边 x x x 连向右边 y y y 的边权是 d p y , ( x , j ) − g y dp_{y,(x,j)}-g_y dpy,(x,j)gy,最后答案加上 ∑ g y \sum g_y gy
  注意到只有 j j j 的儿子数 ≤ i \le i i 的儿子数才有意义,因此这样一次 KM 的时间复杂度是 O ( d e g j 2 ⋅ d e g i ) O(deg_j^2 \cdot deg_i) O(degj2degi),如果每一对 ( i , j , p j ) (i,j,pj) (i,j,pj) 都做一次 KM 的话,时间复杂度是
∑ i ∑ j d e g j ⋅ O ( d e g j 2 ⋅ d e g i ) = O ( n N 3 ) \sum_i \sum_j deg_j \cdot O(deg_j^2 \cdot deg_i) = O(nN^3) ijdegjO(degj2degi)=O(nN3)

  T 掉了。

  所以这里加一个改进,观察到, d p i , ( j , p j ) dp_{i,(j,pj)} dpi,(j,pj) 的二分图匹配,实际上就是 d p i , ( j , 0 ) dp_{i,(j,0)} dpi,(j,0) 的二分图匹配删去左边 p j pj pj 这个点。既然如此,就没必要重新跑一边 KM,直接用最短路退流就好了。
  具体来说,首先做出 d p i , ( j , 0 ) dp_{i,(j,0)} dpi,(j,0) 的 KM,答案为 a n s ans ans,然后求出左边每个点走交错路到达右边结点的最短路(从左到右只能退流匹配边,从右到左只能走非匹配边),记为 a u g p j aug_{pj} augpj,那么 d p i , ( j , p j ) = a n s + a u g p j dp_{i,(j,pj)}=ans+aug_{pj} dpi,(j,pj)=ans+augpj。这个最短路用 floyd 就好了。

  来分析时间复杂度。KM 的部分现在是
∑ i ∑ j O ( d e g j 2 ⋅ d e g i ) = O ( n N 2 ) \sum_i \sum_j O(deg_j^2 \cdot deg_i) = O(nN^2) ijO(degj2degi)=O(nN2)

  floyd 的部分需要注意姿势,如果直接对 d e g i + d e g j deg_i+deg_j degi+degj 个点跑最短路,或者对右边的 d e g i deg_i degi 个点跑最短路,时间都是不对的:
∑ i ∑ j ( d e g j + d e g i ) 3 = O ( n 3 N + n N 3 ) ∑ i ∑ j d e g i 3 = O ( n 3 N ) \begin{aligned} &\sum_i \sum_j (deg_j+deg_i)^3 = O(n^3N+nN^3) \\ &\sum_i \sum_j deg_i^3 = O(n^3N) \end{aligned} ij(degj+degi)3=O(n3N+nN3)ijdegi3=O(n3N)

  要对左边的点跑 floyd,用到的性质是“左边的点数 ≤ \le 右边的点数”,因此时间复杂度是
∑ j d e g j 3 ⋅ n d e g j = O ( n N 2 ) \sum_j deg_j^3 \cdot \frac{n}{deg_j} = O(nN^2) jdegj3degjn=O(nN2)

  具体来说,对于左边的两点 x , y x,y x,y,设它们的 KM 匹配点分别为 x ′ , y ′ x',y' x,y,二分图从 i i i j j j 的边权为 w i , j w_{i,j} wi,j,那么在 floyd 的初始距离中, x x x y y y 的距离为 − w x , x ′ + w y , x ′ -w_{x,x'}+w_{y,x'} wx,x+wy,x。这样求出最短路以后,退流 x x x 的答案为 a u g x = min ⁡ y d i s x , y − w y , y ′ aug_x=\min_{y} dis_{x,y}-w_{y,y'} augx=minydisx,ywy,y

  注意一个细节,如果 j j j 的儿子数比 i i i 多 1 个(即 d e g j = d e g i − 1 + 1 = d e g i deg_j=deg_i-1+1=deg_i degj=degi1+1=degi),那么 d p i , ( j , 0 ) dp_{i,(j,0)} dpi,(j,0) 的 KM 是不合法的,但 d p i , ( j , p j ) dp_{i,(j,pj)} dpi,(j,pj) 的 KM 都是合法的。这里可以给右边加一个空点,跑 d p i , ( j , 0 ) dp_{i,(j,0)} dpi,(j,0) 的 KM 但不要更新答案,然后退流的时候,强制 floyd 的终点是右边这个空点。

代码

// 这里的 KM 跑的是二分图最大权匹配,所以边权取反,floyd 跑最长路

#include
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;

typedef long long LL;

const int maxn=505;
const LL inf=2139062143;

int n,m,N,c[maxn];
vector<int> e[maxn],em[maxn];
map<pair<int,int>,int> M;

LL lx[maxn],ly[maxn],slack[maxn],mp[maxn][maxn];
int f[maxn],pre[maxn];
bool vis[maxn];
LL KM(int nl,int nr)
{
	fo(i,1,nl)
	{
		lx[i]=-inf;
		fo(j,1,nr) lx[i]=max(lx[i],mp[i][j]);
	}
	memset(ly,0,sizeof(LL)*(nr+1));
	memset(f,0,sizeof(int)*(nr+1));
	memset(pre,0,sizeof(int)*(nr+1));
	fo(i,1,nl)
	{
		memset(slack,127,sizeof(LL)*(nr+1));
		memset(vis,0,sizeof(bool)*(nr+1));
		f[0]=i;
		int py=0, nextpy;
		for(; f[py]; py=nextpy)
		{
			int px=f[py];
			LL d=inf<<3;
			vis[py]=1;
			fo(j,1,nr) if (!vis[j])
			{
				if (lx[px]+ly[j]-mp[px][j]<slack[j]) slack[j]=lx[px]+ly[j]-mp[px][j], pre[j]=py;
				if (slack[j]<d) d=slack[j], nextpy=j;
			}
			fo(j,0,nr) if (vis[j]) lx[f[j]]-=d, ly[j]+=d;
				else slack[j]-=d;
		}
		for(; py; py=pre[py]) f[py]=f[pre[py]];
	}
	LL re=0;
	fo(i,1,nl) re+=lx[i];
	fo(i,1,nr) re+=ly[i];
	return re;
}

LL dis[maxn][maxn],aug[maxn];
int ff[maxn];
void floyd(int nl,int nr,bool ty)
{
	fo(y,1,nr) if (f[y]) ff[f[y]]=y;
	fo(i,1,nl)
		fo(j,1,nl) dis[i][j]=(i==j) ?0 :mp[j][ff[i]]-mp[i][ff[i]];
	
	fo(k,1,nl)
		fo(i,1,nl) if (i!=k)
			fo(j,1,nl) if (j!=i && j!=k) dis[i][j]=max(dis[i][j],dis[i][k]+dis[k][j]);
	
	if (ty)
	{
		fo(x,1,nl) aug[x]=dis[x][f[nr]];
	} else
	{
		fo(x,1,nl)
		{
			aug[x]=-mp[x][ff[x]];
			fo(y,1,nl) aug[x]=max(aug[x],dis[x][y]-mp[y][ff[y]]);
		}
	}
}

LL dp[maxn][3*maxn],g[maxn];
int s0,s[maxn];
void dfs(int k,int last)
{
	for(int son:e[k]) if (son!=last) dfs(son,k);
	
	s0=0;
	LL gsum=0;
	for(int son:e[k]) if (son!=last) s[++s0]=son, gsum+=g[son];
	g[k]=inf;
	for(int i=1, cnt=0, sz; i<=N; i++, cnt+=sz)
	{
		sz=em[i].size();
		if (sz-1>s0)
		{
			fo(j,0,sz-1) dp[k][cnt+j]=inf;
			continue;
		}
		fo(x,0,sz-1)
		{
			int id=M[make_pair(em[i][x],i)];
			fo(y,1,s0) mp[x+1][y]=g[s[y]]-dp[s[y]][id];
		}
		if (sz>s0)
		{
			s[++s0]=0;
			fo(x,1,sz) mp[x][s0]=0;
		}
		LL ans=gsum-KM(sz,s0);
		if (s[s0] || !s0) g[k]=min(g[k],c[i]+ans);
		
		floyd(sz,s0,(s[s0]==0 && s0>0));
		fo(x,0,sz-1)
			dp[k][cnt+x]=min(ans-aug[x+1],inf);
		
		s0-=(s0>0 && s[s0]==0);
	}
}

int main()
{
	scanf("%d",&n);
	fo(i,2,n)
	{
		int x,y;
		scanf("%d %d",&x,&y);
		e[x].push_back(y), e[y].push_back(x);
	}
	scanf("%d",&m);
	fo(i,1,m)
	{
		int tn,tc;
		scanf("%d %d",&tn,&tc);
		fo(j,2,tn)
		{
			int x,y;
			scanf("%d %d",&x,&y);
			em[N+x].push_back(N+y), em[N+y].push_back(N+x);
		}
		fo(j,1,tn) c[N+j]=tc;
		N+=tn;
	}
	
	int tot=0;
	fo(i,1,N)
		for(int go:em[i]) M[make_pair(i,go)]=tot++;
	
	dfs(1,0);
	
	if (g[1]>=inf) puts("impossible"); else printf("%lld\n",g[1]);
}

你可能感兴趣的:(算法_网络流,算法_DP)