tree

<pre name="code" class="cpp">#include<iostream>  
#include<cstdio>  
#include<algorithm>  
#include<cstring>  
#define N 200003  
using namespace std;  

// ---------------------------------------------------------------------------
// Link-cut tree over nodes 1..n; node id 0 is the null node and its fields
// are expected to stay 0.
//
//   n, m     : node count and operation count (filled in by main).
//   key[x]   : value stored at node x.
//   sum[x]   : sum of key over x's splay subtree, with every operation pending
//              strictly below x folded in, but NOT x's own pending operation.
//   size[x]  : number of nodes in x's splay subtree.
//   ch[x][0], ch[x][1] : left / right splay child (index 2 is unused here).
//   fa[x]    : parent pointer; x is a splay root when fa[x] does not list x
//              as one of its children.
//   rev[x]   : 1 when x's subtree still has to be mirrored.
//   top, st  : scratch stack used by splay().
//   delta[x] : kind of operation pending on x AND its whole splay subtree
//              (0 = none, 1 = add num[x], anything else = multiply by num[x]).
//              key[x] and sum[x] do not include it until pushdown1(x) runs.
//   num[x]   : operand of the pending operation.
//
// NOTE(review): everything is plain int, so sums can overflow; if the task
// this was written for requires modular arithmetic, it is not applied here.
// ---------------------------------------------------------------------------
int n,m;  
int sum[N],key[N],size[N],ch[N][3],fa[N],rev[N],top,st[N];  
int delta[N],num[N];

// The array `size` is always spelled `::size` below: with `using namespace std;`
// an unqualified `size` is ambiguous with std::size when built as C++17 or later.

// Returns 1 when x is the root of its splay tree (its parent, if any, does not
// list it as a child), 0 otherwise.
int isroot(int x)
{
  return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;
}

// Returns 1 when x is the right child of fa[x], 0 otherwise.
int get(int x)  
{  
  return ch[fa[x]][1]==x;  
}

// Sum of x's splay subtree with x's own pending operation folded in.
// Yields 0 for the null node because delta[0] and sum[0] stay 0.
static int tagged_sum(int x)
{
  if (delta[x]==0) return sum[x];
  if (delta[x]==1) return sum[x]+::size[x]*num[x];
  return sum[x]*num[x];
}

// Recomputes size[x] and sum[x] from x's children. Children may still carry a
// pending operation, so their sums are read through tagged_sum(). x's own
// pending operation (if any) is deliberately left out of sum[x].
void update(int x)
{
  if (!x) return;
  ::size[x]=::size[ch[x][0]]+::size[ch[x][1]]+1;
  sum[x]=tagged_sum(ch[x][0])+tagged_sum(ch[x][1])+key[x];
}  

// Applies the operation pending on x to key[x] and sum[x] and hands it on to
// x's children, leaving x without a pending operation.
void pushdown1(int x)
{
  // Never touch the null node: a tag stored on node 0 would be "pushed" from
  // node 0 to itself without end.
  if (!x||!delta[x]) return;
  const int kind=delta[x];
  const int operand=num[x];
  for (int side=0;side<2;side++)
  {
    const int child=ch[x][side];
    if (!child) continue;
    if (delta[child]==kind)
    {
      // Two operations of the same kind merge into a single tag.
      if (kind==1) num[child]+=operand;
      else num[child]*=operand;
      continue;
    }
    // There is only one tag slot per node, so an older operation of the other
    // kind has to be applied to the child before it can take the new one.
    if (delta[child]) pushdown1(child);
    delta[child]=kind;
    num[child]=operand;
  }
  if (kind==1)
  {
    key[x]+=operand;
    sum[x]+=::size[x]*operand;   // every node of the subtree gains `operand`
  }
  else
  {
    key[x]*=operand;
    sum[x]*=operand;             // every node of the subtree is scaled
  }
  delta[x]=0;
  num[x]=0;
}

// Resolves a pending reversal on x: swaps x's children and toggles the flag
// on each existing child.
void pushdown(int x)
{
  if (!rev[x]) return;
  const int left=ch[x][0];
  const int right=ch[x][1];
  // Toggle rather than set: a child that already has a pending reversal must
  // end up with none.
  if (left) rev[left]^=1;
  if (right) rev[right]^=1;
  rev[x]=0;
  ch[x][0]=right;
  ch[x][1]=left;
}

// Rotates x above its splay parent. x must not be a splay root, and pending
// reversals on x and its parent must already have been resolved.
void rotate(int x)  
{  
  int y=fa[x];
  // Pending operations must leave y and x before their subtrees change shape.
  // When called from splay() both calls find nothing to do.
  pushdown1(y);
  pushdown1(x);
  int z=fa[y]; int which=get(x);  
  if (!isroot(y)) ch[z][ch[z][1]==y]=x;  
  int inner=ch[x][which^1];
  ch[y][which]=inner;
  if (inner) fa[inner]=y;        // do not write a parent into the null node
  ch[x][which^1]=y; fa[y]=x; fa[x]=z; 
  update(y); update(x); 
}  

// Rotates x to the root of its splay tree. On return x carries neither a
// pending reversal nor a pending operation, so key[x] and sum[x] are exact.
void splay(int x)  
{  
  top=0; st[++top]=x;  
  for (int i=x;!isroot(i);i=fa[i])  
    st[++top]=fa[i];  
  // Resolve reversals and pending operations from the splay root down to x,
  // so that every node taking part in a rotation is clean.
  for (int i=top;i>=1;i--)  
  {
    pushdown(st[i]);  
    pushdown1(st[i]);
  }
  while(!isroot(x))  
  {  
    int y=fa[x];  
    if (!isroot(y))  
      rotate(get(x)==get(y)?y:x);  
    rotate(x);  
  }  
}  

// Walks from x up through the fa[] chain; at each step the visited node is
// splayed and its right child is replaced by the tree built so far.
void access(int x)  
{  
  int t=0;  
  while(x)  
  {  
    splay(x);   
    ch[x][1]=t;  
    update(x);                   // the right child changed: refresh size/sum
    t=x; x=fa[x];  
  }  
}  

// access(x), splay(x), then flips the pending-reversal flag on x.
void rever(int x)  
{  
  access(x); splay(x); rev[x]^=1;  
}  

// Attaches x below y by setting fa[x]=y after rever(x).
void link(int x,int y)  
{  
  rever(x); fa[x]=y; splay(x);  
}  

// Removes the edge between x and y. Does nothing when, after exposing the
// path, x is not y's lone left child (i.e. x and y are not joined by an edge).
void cut(int x,int y)
{
  rever(x); access(y); splay(y);
  if (ch[y][0]!=x||ch[x][0]||ch[x][1]) return;
  ch[y][0]=0;
  fa[x]=0;
  update(y);                     // y lost its left child: refresh size/sum
}
// Exposes the path between x and y and records a pending operation on it:
// kind 1 adds `operand` to every node on the path, kind 2 multiplies by it.
static void tag_path(int x,int y,int kind,int operand)
{
  rever(x); access(y); splay(y);
  // The node has a single tag slot, so an earlier pending operation is
  // applied before the new one is stored.
  if (delta[y]) pushdown1(y);
  delta[y]=kind;
  num[y]=operand;
}

// Reads the tree and the operations from "a.in" and prints query answers.
//
// Input, as parsed below: n and m; then n-1 edges "x y"; then m operations,
// each introduced by one character:
//   '-' x y z k : remove edge x-y, then add edge z-k
//   '+' x y z   : add z to every node on the path x..y
//   '*' x y z   : multiply every node on the path x..y by z
//   any other   : x y -> print the sum of the nodes on the path x..y
// Every node starts with value 1.
//
// Returns 1 when the input file cannot be opened or the header/edge list is
// malformed; otherwise 0.
// NOTE(review): node ids inside edges and operations are not range-checked.
int main()  
{  
  if (!freopen("a.in","r",stdin))
  {
    fprintf(stderr,"cannot open a.in\n");
    return 1;
  }
  if (scanf("%d%d",&n,&m)!=2) return 1;
  // Node ids index the global arrays directly, so n must fit inside them.
  const int capacity=static_cast<int>(sizeof(key)/sizeof(key[0]));
  if (n<0||n>=capacity) return 1;
  for (int i=1;i<=n;i++) key[i]=1; 
  for (int i=1;i<=n-1;i++)  
  {  
    int x,y;
    if (scanf("%d%d",&x,&y)!=2) return 1;
    link(x,y);  
  }  
  for (int i=1;i<=m;i++)
  {
    // " %c" skips any whitespace left over from the previous line, so the
    // operation character is read correctly even when no edge line came
    // before it (n==1).
    char c;
    if (scanf(" %c",&c)!=1) break;
    if (c=='-')
    {
      int x,y,z,k;
      if (scanf("%d%d%d%d",&x,&y,&z,&k)!=4) break;
      cut(x,y); link(z,k);
    }
    else if (c=='+'||c=='*')
    {
      int x,y,z;
      if (scanf("%d%d%d",&x,&y,&z)!=3) break;
      tag_path(x,y,c=='+'?1:2,z);
    }
    else
    {
      int x,y;
      if (scanf("%d%d",&x,&y)!=2) break;
      rever(x); access(y); splay(y);
      // An operation still pending on y is not part of sum[y] yet; apply it
      // before reading the answer.
      if (delta[y]) pushdown1(y);
      printf("%d\n",sum[y]);
    }
  }  
  return 0;
}   


 

你可能感兴趣的:(tree)