[wc2013]糖果公园

树上莫队算法:

#include 
#include 
#include 
#include 
#include 
#include 

/* 64-bit integer used for the accumulated answers. */
typedef long long int64;
/* qsort-style comparator signature (not referenced in the visible code). */
typedef int (*cmp_t) (const void *, const void *);

/* printf conversion for int64; the WIN32 build uses the %I64d spelling. */
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
/* Large sentinel constant (not referenced in the visible code). */
#define oo 0x13131313
/* Swaps two int lvalues; relies on the GNU statement-expression extension. */
#define swap(x, y) ({ int _ = x; x = y; y = _; })

/*
 * Reads the next run of decimal digits from stdin into *x.
 * Every non-digit character in front of the run is skipped (so a leading
 * '-' is ignored), and the single character that ends the run is consumed.
 * If the input ends before any digit is found, *x is set to 0.
 */
void get(int *x)
{
	/* int, not char: EOF must stay distinguishable from every byte value,
	 * otherwise the skip loop below never terminates at end of input. */
	int ch = getchar();

	while (ch != EOF && (ch < '0' || ch > '9'))
		ch = getchar();

	*x = 0;
	for (; '0' <= ch && ch <= '9'; ch = getchar())
		*x = *x * 10 + (ch - '0');
}

/* Adjacency-list node: t is the neighbouring vertex, n the next node. */
typedef struct edge { int t; struct edge *n; } edge;
/* Operation record, filled in main as {x, y, t}: t is the operation's index
 * in input order; n chains the queries that share one block pair. */
typedef struct elem { int x, y, t; struct elem *n; } elem;
/* LCA request stored on one endpoint: t is the other endpoint, f the index
 * of the lca[] slot to fill, n the next request of the same vertex. */
typedef struct info { int t, f; struct info *n; } info;
typedef int arr32[200100];

/* Edge pool (handed out through adjc) and per-vertex adjacency heads. */
edge mem[200005], *adjc = mem, *adj[100005];
/* Operation pool (handed out through me); modify[] lists the type-0
 * operations, que[] the queries, ask[i][j] heads the queries whose two
 * interval ends lie in blocks i and j. */
elem memo[100005], *me = memo, *modify[100005], *que[100005], *ask[60][60];
/* LCA request pool (handed out through memt) and per-vertex request heads. */
info Mem[200005], *memt = Mem, *list[100005];
/* v[k]: value read for type k; w[j]: weight applied to the j-th occurrence
 * of a type; c/cc: current and initial type of each vertex; ufs: union-find
 * parents; br: vertex at each bracket-sequence position; ll/rr: the two
 * positions of each vertex; lca: LCA per query record; bel: block of each
 * position; sum: occurrences of each type among the toggled-in vertices;
 * type: kind of each operation; xx/yy: ordered endpoints of each query. */
arr32 v, w, c, cc, ufs, br, ll, rr, lca, bel, sum, type, xx, yy;
/* appr[p]: vertex p is currently toggled in; ans: running total maintained
 * by trans(); anss[i]: answer of operation i. */
bool appr[100005]; int64 ans, anss[100005];
/* tot: length of the bracket sequence; B: block size; bt: number of blocks;
 * mt/qt: number of modifications/queries. hehe and haha are not referenced
 * in the visible code. */
int n, m, Q, B, tot, mt, qt, bt, hehe, haha;

/* Returns the representative of x in ufs, re-pointing every vertex on the
 * walked chain directly at that representative. */
int find(int x)
{
	if (ufs[x] != x)
		ufs[x] = find(ufs[x]);
	return ufs[x];
}

void dfs(int u, int fa)
{
	edge *e;
	br[ll[u] = ++tot] = u;
	for (e = adj[u]; e; e = e->n)
		if (e->t != fa) dfs(e->t, u);
	br[rr[u] = ++tot] = u;
}

void tarjan(int u, int fa)
{
	edge *e; info *i; ufs[u] = u;
	for (e = adj[u]; e; e = e->n)
		if (e->t != fa)
			tarjan(e->t, u), ufs[e->t] = u;
	for (i = list[u]; i; i = i->n)
		if (ufs[i->t])
			lca[i->f] = find(i->t);
}

void trans(int p)
{
	if (appr[p])
		appr[p] = 0, ans -= (int64) v[c[p]] * w[sum[c[p]]--];
	else
		appr[p] = 1, ans += (int64) v[c[p]] * w[++sum[c[p]]];
}

/*
 * Entry point. Reads the tree and all operations from park.in, answers the
 * path queries offline (grouped by the blocks that hold the two ends of
 * their bracket-sequence interval), and prints one answer per query to
 * park.out in input order.
 */
int main()
{
	freopen("park.in", "r", stdin);
	freopen("park.out", "w", stdout);

	int i, j, k;
	get(&n), get(&m), get(&Q);
	for (i = 1; i <= m; ++i)
		get(v + i);
	for (i = 1; i <= n; ++i)
		get(w + i);

	/* n - 1 undirected edges; each is pushed onto both endpoints' lists */
	for (i = 1; i < n; ++i) {
		int a, b;
		get(&a), get(&b);
		*adjc = (edge) {b, adj[a]}, adj[a] = adjc++;
		*adjc = (edge) {a, adj[b]}, adj[b] = adjc++;
	}

	/* cc keeps the initial types so that c can be restored per block pair */
	for (i = 1; i <= n; ++i)
		get(c + i), cc[i] = c[i];

	/* get brackets sequence: every vertex appears at ll[] and at rr[] */
	dfs(1, 0);
	/* block size B = (smallest b with b^3 >= tot)^2, roughly tot ^ (2 / 3) */
	for (B = 1; B * B * B < tot; ++B); B *= B;

	/* bel[pos] = 1-based block number of sequence position pos */
	for (i = j = bt = 1; i <= tot; ++i, ++j) {
		if (j > B) j = 1, ++bt;
		bel[i] = bt;
	}

	for (i = 0; i < Q; ++i) {
		int x, y;
		get(type + i), get(&x), get(&y);

		if (!type[i])
			/* type 0: vertex x takes type y; i is the time stamp */
			*me = (elem) {x, y, i}, modify[mt++] = me++;
		else {
			/* for LCA queries: the request is stored on both endpoints and
			 * its result slot is this query's index inside memo */
			*memt = (info) {y, me - memo, list[x]}, list[x] = memt++;
			*memt = (info) {x, me - memo, list[y]}, list[y] = memt++;

			/* asked intervals: order the endpoints by entry position */
			if (ll[x] > ll[y]) swap(x, y);
			xx[i] = x, yy[i] = y;
			/* start at rr[x] when x closes before y does, else at ll[x] */
			x = rr[x] < rr[y] ? rr[x] : ll[x]; y = ll[y];
			*me = (elem) {x, y, i}; que[qt++] = me++;
		}
	}

	/* divide queries into groups; walking que backwards while prepending
	 * leaves every ask[][] list in input order */
	for (i = qt; i--; ) {
		elem *j = que[i];
		j->n = ask[bel[j->x]][bel[j->y]], ask[bel[j->x]][bel[j->y]] = j;
	}

	/* LCA */
	tarjan(1, 0);

	/* answer */
	for (i = 1; i <= bt; ++i)
		for (j = i; j <= bt; ++j)
			if (ask[i][j]) {
				int l = (i - 1) * B + 1, r = l, f;
				elem *e, *d;
				/* restart from the initial types with a window [l, r] that
				 * holds only the first position of block i; the byte counts
				 * assume 4-byte int and 1-byte bool */
				memset(sum, 0, (m + 1) << 2);
				memset(appr, 0, (n + 1));
				memcpy(c, cc, (n + 1) << 2);
				ans = k = 0; trans(br[l]);

				for (e = ask[i][j]; e; e = e->n) {
					/* slide both window ends onto the query interval */
					if (l < e->x) for (; l != e->x; ++l) trans(br[l]);
					else if (l > e->x) do --l, trans(br[l]); while (l != e->x);
					if (r > e->y) for (; r != e->y; --r) trans(br[r]);
					else if (r < e->y) do ++r, trans(br[r]); while (r != e->y);

					/* apply every modification issued before this query */
					for (; k < mt && modify[k]->t < e->t; ++k) {
						d = modify[k];
						/* flag: exactly one of the vertex's two positions
						 * lies inside the window; only then is it toggled
						 * out, retyped and toggled back in */
						bool flag = (l <= ll[d->x] && ll[d->x] <= r) ^ (l <= rr[d->x] && rr[d->x] <= r);
						if (flag) trans(d->x);
						c[d->x] = d->y;
						if (flag) trans(d->x);
					}

					anss[e->t] = ans;

					/* when the LCA is neither endpoint, add it as one more
					 * occurrence of its type */
					if (f = lca[e - memo], f != xx[e->t] && f != yy[e->t])
						anss[e->t] += (int64) v[c[f]] * w[sum[c[f]] + 1];
				}
			}

	/* print the answers of the non-zero-type operations in input order */
	for (i = 0; i < Q; ++i)
		if (type[i])
			printf(fmt64 "\n", anss[i]);

	return 0;
}



70分树上莫队算法:

#include 
#include 
#include 
#include 
#include 
#include 

/* Shorthand for unsigned (not referenced in the visible code). */
#define uns unsigned
/* 64-bit integer used for the accumulated answers. */
#define int64 long long
/* printf conversion for int64; the WIN32 build uses the %I64d spelling. */
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
/* Large sentinel constant (not referenced in the visible code). */
#define oo 0x13131313
/* Counts i from 0 to n - 1; i must already be declared by the caller. */
#define REP(i, n) for (i = 0; i < (n); ++i)
/* Capacity of every per-vertex / per-position array. */
#define maxn 200005

using namespace std;

/* Br: length of the bracket sequence built by dfs(). */
int n, m, Q, Br;
/* Adjacency-list node (t: neighbour, n: next), its pool and per-vertex heads. */
struct edge { int t; edge *n; } edges[maxn * 2], *adj = edges, *lst[maxn];
typedef int array[maxn];
/* v[k]: value read for type k; w[j]: weight applied to the j-th occurrence of
 * a type; c: type of each vertex; x/y: ordered endpoints of each query; p:
 * query indices in processing order; xx/yy: bracket-sequence interval of each
 * query; br: vertex at each sequence position; ll/rr: the two positions of
 * each vertex; hehe: block of each position; ufs: union-find parents; lca:
 * LCA per query; sum: occurrences of each type among the toggled-in vertices;
 * pa: parent of each vertex (filled by dfs1 for the brute-force path). */
array v, w, c, x, y, p, xx, yy, br, ll, rr, hehe, ufs, lca, sum, pa;
/* LCA request (t: other endpoint, f: query index, n: next), its pool and
 * per-vertex heads. */
struct elem { int t, f; elem *n; } elems[maxn * 2], *eptr = elems, *stk[maxn];
/* result: running total maintained by trans(); ans[i]: answer of query i;
 * appr[p]: vertex p is currently toggled in. */
int64 result, ans[maxn]; bool appr[maxn];

void link(int a, int b)
{
	*adj = (edge){b, lst[a]}, lst[a] = adj++;
	*adj = (edge){a, lst[b]}, lst[b] = adj++;
}

void dfs(int u, int fa)
{
	br[ll[u] = ++Br] = u;
	for (edge *e = lst[u]; e; e = e->n)
		if (e->t != fa) dfs(e->t, u);
	br[rr[u] = ++Br] = u;
}

void add(int a, int b, int c)
{
	*eptr = (elem){b, c, stk[a]}, stk[a] = eptr++;
	*eptr = (elem){a, c, stk[b]}, stk[b] = eptr++;
}

/* Returns the representative of x in ufs, re-pointing every vertex on the
 * walked chain directly at that representative. */
int find(int x)
{
	if (ufs[x] == x)
		return x;
	const int root = find(ufs[x]);
	ufs[x] = root;
	return root;
}

void tarjan(int u, int fa)
{
	ufs[u] = u;
	for (elem *e = stk[u]; e; e = e->n)
		if (ufs[e->t])
			lca[e->f] = find(e->t);
	for (edge *e = lst[u]; e; e = e->n)
		if (e->t != fa)
			tarjan(e->t, u), ufs[e->t] = u;
}

/* Ordering of query indices: by the block holding the left interval end
 * first, then by the right interval end. */
bool nicer(int a, int b)
{
	const int blockA = hehe[xx[a]];
	const int blockB = hehe[xx[b]];

	if (blockA != blockB)
		return blockA < blockB;
	return yy[a] < yy[b];
}

void trans(int x)
{
	if (appr[x]) {
		appr[x] = 0;
		result -= (int64)v[c[x]] * w[sum[c[x]]--];
	} else {
		appr[x] = 1;
		result += (int64)v[c[x]] * w[++sum[c[x]]];
	}
}

/*
 * Brute-force path query. Walks the parent pointers in pa to collect the
 * type of every vertex on the path between x and y, sorts the collected
 * types, and returns the sum of v[type] * w[k] over the k-th occurrence of
 * each type.
 */
int64 ask(int x, int y)
{
	static int Mark, tot;
	static array mark, a;
	int i, f;
	int64 res = 0;

	/* stamp x and all of its ancestors with a fresh mark */
	++Mark, tot = 0;
	for (i = x; i; i = pa[i])
		mark[i] = Mark;

	/* climb from y to the first stamped vertex f, which is the LCA */
	for (f = y; mark[f] < Mark; f = pa[f])
		a[++tot] = c[f];
	a[++tot] = c[f];
	for (i = x; i != f; i = pa[i])
		a[++tot] = c[i];

	sort(a + 1, a + tot + 1);

	/* one pass per run of equal types; the bound is tested before a[i] is
	 * read so that no stale slot behind the collected range is inspected */
	for (i = 1; i <= tot; ) {
		const int kind = a[i];
		int seen = 0;
		for (; i <= tot && a[i] == kind; ++i)
			res += (int64)v[kind] * w[++seen];
	}
	return res;
}

void dfs1(int u, int fa)
{
	edge *e; pa[u] = fa;
	for (e = lst[u]; e; e = e->n)
		if (e->t != fa) dfs1(e->t, u);
}

/*
 * Entry point. Reads the tree from park.in and writes the answers to
 * park.out. Small inputs are answered online by brute force; larger inputs
 * are answered offline by sorting the queries over the bracket sequence.
 */
int main()
{
	freopen("park.in", "r", stdin);
	freopen("park.out", "w", stdout);

	int i, j, k;
	scanf("%d%d%d", &n, &m, &Q);
	for (i = 1; i <= m; ++i)
		scanf("%d", v + i);
	for (i = 1; i <= n; ++i)
		scanf("%d", w + i);
	for (i = 1; i < n; ++i) {
		int a, b;
		scanf("%d%d", &a, &b);
		link(a, b);
	}
	for (i = 1; i <= n; ++i)
		scanf("%d", c + i);

	if (n <= 20000 && m <= 20000) {
		/* small input: a non-zero t is a path query answered by ask(), a
		 * zero t assigns type y to vertex x in place */
		for (dfs1(1, 0); Q--; ) {
			int t, x, y;
			scanf("%d%d%d", &t, &x, &y);
			t ? printf(fmt64"\n", ask(x, y)) : c[x] = y;
		}
		exit(0);
	}

	dfs(1, 0);
	/* hehe[pos] = block number of sequence position pos, 400 per block */
	for (i = 1, j = k = 0; i <= Br; ++i) {
		if (i > j) j += 400, ++k;
		hehe[i] = k;
	}
	REP(i, Q) {
		int a, b;
		/* the operation type is read into j and never examined, so every
		 * operation on this path is treated as a query */
		scanf("%d%d%d", &j, &a, &b);
		/* order the endpoints by entry position, then take the interval
		 * from rr[a] when a closes before b opens, else from ll[a] */
		if (ll[a] > ll[b]) swap(a, b);
		if (rr[a] < ll[b])
			xx[i] = rr[a], yy[i] = ll[b];
		else
			xx[i] = ll[a], yy[i] = ll[b];
		x[i] = a, y[i] = b, p[i] = i;
		add(x[i], y[i], i);
	}
	tarjan(1, 0);

	sort(p, p + Q, nicer);
	/* from here on the locals ll and rr shadow the global arrays; they hold
	 * the window ends left behind by the previous query */
	int ll = 0, rr = 0;
	REP(i, Q) {
		int u = p[i];
		int l = xx[u], r = yy[u];
		/* toggle the positions between the old and the new left end */
		if (ll < l)
			for (j = ll; j < l; ++j) trans(br[j]);
		else if (ll > l)
			for (j = l; j < ll; ++j) trans(br[j]);
		/* toggle the positions between the old and the new right end */
		if (rr > r)
			for (j = rr; j > r; --j) trans(br[j]);
		else if (rr < r)
			for (j = r; j > rr; --j) trans(br[j]);
		ans[u] = result;
		/* when the LCA is neither endpoint, add it as one more occurrence
		 * of its type */
		if (lca[u] != x[u] && lca[u] != y[u])
			ans[u] += (int64)v[c[lca[u]]] * w[sum[c[lca[u]]] + 1];
		ll = l, rr = r;
	}
	/* one line per operation, in input order */
	REP(i, Q)
		printf(fmt64"\n", ans[i]);
}

70分树分块:

#include 
#include 
#include 
#include 
#include 
#include 

/* Counts i from 0 to n - 1; i must already be declared by the caller. */
#define REP(i, n) for (i = 0; i < (n); ++i)
/* Walks the adjacency list of vertex j with the edge pointer i. */
#define FER(i, j) for (i = lst[j]; i; i = i->n)
/* 64-bit integer used for the accumulated answers. */
#define int64 long long
/* printf conversion for int64; the WIN32 build uses the %I64d spelling. */
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
/* Large sentinel constant (not referenced in the visible code). */
#define oo 0x13131313
/* Capacity of the per-vertex and per-type arrays. */
#define maxn 90002
/* Side length of the two-level count table and the group-size threshold. */
#define BLOCK 300

/* now: clock() value taken at the start of main, used for stderr timings. */
using namespace std; double now;

/*
 * Reads the next run of decimal digits from stdin into x (any integer type).
 * Every non-digit character in front of the run is skipped (so a leading
 * '-' is ignored), and the single character that ends the run is consumed.
 * If the input ends before any digit is found, x is set to 0.
 */
template <class T> void read(T &x)
{
	/* int, not char: EOF must stay distinguishable from every byte value,
	 * otherwise the skip loop below never terminates at end of input. */
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9'))
		ch = getchar();

	x = 0;
	for (; '0' <= ch && ch <= '9'; ch = getchar())
		x = x * 10 + (ch - '0');
}

/* w[j]: weight applied to the j-th occurrence of a type; c: type of each
 * vertex; Mod[i]: cached i % BLOCK; v[k]: value read for type k. */
int n, m, Q, w[maxn], c[maxn]; short Mod[maxn]; int64 v[maxn];
/* pa: parent; size/ufs: size and union-find parent of the vertex groups
 * formed in dfs(); dep: depth; dfn: visiting order (Dfn is its counter);
 * ca: vertex recorded for a group in dfs() — presumably the group's top
 * vertex, TODO confirm. */
int pa[maxn], size[maxn], ufs[maxn], dep[maxn], dfn[maxn], Dfn, ca[maxn];
/* F, f: start vertex and current turning vertex used by init()/bfs();
 * mark/Mark, a/tot: scratch state of the brute-force ask(). */
int F, f, mark[maxn], tot, a[maxn], Mark;
/* buf: storage for the precomputed tables (buft is the allocation cursor);
 * ans[g]: table of start vertex g, or null while unallocated; Ans: the table
 * currently being filled. */
int64 buf[maxn * BLOCK], *buft = buf, *ans[maxn], *Ans;

/* Adjacency-list node (t: neighbour, n: next), its pool and per-vertex heads;
 * fr[x]: the node in the parent's list that points at x. */
struct edge { int t; edge *n; } edges[maxn * 2], *adj = edges, *lst[maxn], *fr[maxn];
/* One chunk of BLOCK counters, and the pool the chunks are taken from. */
struct block { int a[BLOCK]; } blocks[maxn + BLOCK], *btot = blocks;
/* Two-level table of per-type counters addressed by a 1-based type; sum[u]
 * is the table of vertex u, built from its parent's table by inherit(). */
struct array { block *a[BLOCK]; int operator[](int); } sum[maxn];

/* Returns the counter stored for the 1-based type b. */
int array::operator[](int b)
{
	const int idx = b - 1;
	const block *chunk = a[idx / BLOCK];
	return chunk->a[Mod[idx]];
}

/*
 * Makes a a copy of b in which the counter of the 1-based type pos is one
 * higher. All chunks stay shared with b except the one holding that counter,
 * which is cloned from the pool before being modified.
 */
void inherit(array &a, array &b, int pos)
{
	const int idx = pos - 1;
	const int slot = idx / BLOCK;

	memcpy(&a, &b, sizeof(array));

	block *fresh = btot++;
	memcpy(fresh, a.a[slot], sizeof(block));
	a.a[slot] = fresh;
	++fresh->a[Mod[idx]];
}

/* Returns the representative of x in ufs. A second pass re-points every
 * vertex on the walked chain directly at that representative. */
int find(int x)
{
	int root = x;
	while (ufs[root] != root)
		root = ufs[root];

	while (ufs[x] != x) {
		const int next = ufs[x];
		ufs[x] = root;
		x = next;
	}
	return root;
}

/*
 * Depth-first pass from u (fa is its parent). Fills sum[u], dep[u], dfn[u]
 * and, for every child, pa[] and fr[]; on the way back it merges vertices
 * into groups through ufs/size, using BLOCK as the size threshold.
 */
void dfs(int u, int fa)
{
	/* f: the child whose group is still being grown, -1 while there is none */
	edge *e; int f = -1;
	/* sum[u] = sum[fa] with one more occurrence of c[u] */
	inherit(sum[u], sum[fa], c[u]), dep[u] = dep[fa] + 1, dfn[u] = ++Dfn;
	FER(e, u) if (e->t != fa)
	{
		dfs(e->t, u), fr[e->t] = e, pa[e->t] = u;
		/* open a new group at this child when none is open or when joining
		 * would push the open one past 2 * BLOCK; otherwise fold the child's
		 * group into f and record u in ca[f] */
		if (!~f || size[f] + size[e->t] > BLOCK << 1)
			f = e->t;
		else
			size[f] += size[e->t], ca[f] = u, ufs[e->t] = f;
	}
	/* u starts as a group of its own ... */
	ca[u] = ufs[u] = u, size[u] = 1;
	/* ... and absorbs the last open child group if that is below BLOCK */
	if (~f && size[f] < BLOCK)
		size[u] += size[f], ufs[f] = u;
}

/*
 * Fills Ans[] for S and every vertex below it, breadth first. Each entry is
 * the parent's entry plus v[type] * w[k], where k is computed from the
 * counter tables of F, f and the vertex itself — presumably the number of
 * vertices of that type on the path from F through f to the vertex (TODO
 * confirm).
 */
void bfs(int S)
{
	/* q is a queue stored as a linked list: q[x] is the vertex that follows
	 * x, 0 ends the list; h is the head and t the tail */
	static int q[maxn]; int h, t; edge *e;
	for (q[h = t = S] = 0; h; h = q[h])
	{
		Ans[h] = Ans[pa[h]] + v[c[h]] * w[sum[F][c[h]] - (sum[f][c[h]] << 1) + sum[h][c[h]] + (c[f] == c[h])];
		/* append each child: link it behind the old tail, make it the new
		 * tail and terminate the list after it */
		FER(e, h) if (e->t != pa[h]) q[t = q[t] = e->t] = 0;
	}
}

/*
 * Precomputes the answer tables (comment in the original: "pretreat for the
 * answers between blocks"). For every group representative, the vertex
 * stored in ca[] gets one table of n entries that is filled by bfs() below
 * it and along the chain of its ancestors.
 */
void init()
{
	int i, j; edge *e;
	/* give sum[0] its own chunks from the (zero-initialised) pool */
	REP(i, BLOCK) sum->a[i] = btot++;
	dfs(1, 0);
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
	REP(i, n) if (find(i + 1) == i + 1)
	{
		F = f = ca[i + 1];
		/* this start vertex already has its table */
		if (ans[f]) continue;
		/* carve n entries out of buf; the pointer is shifted by one so that
		 * the table is addressed by 1-based vertex numbers */
		buft += n, Ans = ans[f] = buft - n - 1;
		Ans[f] = v[c[f]] * w[1];
		FER(e, f) if (e->t != pa[f]) bfs(e->t);
		/* climb towards the root; fr[] is null only for the root */
		for (j = f; fr[j]; j = f)
		{
			f = pa[j];
			Ans[f] = Ans[j] + v[c[f]] * w[sum[F][c[f]] - sum[f][c[f]] + 1];
			/* only the edges that follow the one leading to j are scanned,
			 * i.e. the children that dfs() visited after j */
			for (e = fr[j]->n; e; e = e->n)
				if (e->t != pa[f]) bfs(e->t);
		}
	}
}

/*
 * Brute-force path query. Walks the parent pointers in pa to collect the
 * type of every vertex on the path between x and y, sorts the collected
 * types, and returns the sum of v[type] * w[k] over the k-th occurrence of
 * each type.
 */
int64 ask(int x, int y)
{
	int i, f; int64 res = 0;

	/* stamp x and all of its ancestors with a fresh mark */
	++Mark, tot = 0;
	for (i = x; i; i = pa[i]) mark[i] = Mark;

	/* climb from y to the first stamped vertex f, which is the LCA */
	for (f = y; mark[f] < Mark; f = pa[f]) a[++tot] = c[f];
	a[++tot] = c[f];
	for (i = x; i != f; i = pa[i]) a[++tot] = c[i];

	sort(a + 1, a + tot + 1);

	/* one pass per run of equal types; the bound is tested before a[i] is
	 * read so that no stale slot behind the collected range is inspected */
	for (i = 1; i <= tot; )
	{
		const int kind = a[i];
		int seen = 0;
		for (; i <= tot && a[i] == kind; ++i)
			res += v[kind] * w[++seen];
	}
	return res;
}

void Dfs(int u, int fa)
{
	edge *e; pa[u] = fa;
	for (e = lst[u]; e; e = e->n)
		if (e->t != fa) Dfs(e->t, u);
}

void input()
{
	int i; scanf("%d%d%d", &n, &m, &Q);
	REP(i, m) read(v[i + 1]);
	REP(i, n) read(w[i + 1]), Mod[i] = i % BLOCK;
	REP(i, n - 1)
	{
		int a, b; read(a), read(b);
		*adj = (edge){b, lst[a]}, lst[a] = adj++;
		*adj = (edge){a, lst[b]}, lst[b] = adj++;
	}
	REP(i, n) read(c[i + 1]);
	if (n <= 20000 && m <= 20000)
	{
		for (Dfs(1, 0); Q--; )
		{
			int t, x, y; read(t), read(x), read(y);
			t ? printf(fmt64"\n", ask(x, y)) : c[x] = y;
		}
		exit(0);
	}
}

/*
 * Returns the vertex where the upward walks from x and y meet. While both
 * are in the same group (equal ufs entries) the deeper one moves to its
 * parent; otherwise the one whose ufs entry is deeper jumps to the parent of
 * that entry.
 */
int LCA(int x, int y)
{
	while (x != y)
	{
		if (ufs[x] == ufs[y])
		{
			if (dep[x] > dep[y])
				x = pa[x];
			else
				y = pa[y];
		}
		else if (dep[ufs[x]] > dep[ufs[y]])
			x = pa[ufs[x]];
		else
			y = pa[ufs[y]];
	}
	return x;
}

/*
 * Entry point. Reads park.in, builds the precomputed tables and answers the
 * queries into park.out; timings go to stderr. The program stops at the
 * first operation whose type is 0, i.e. modifications are not supported on
 * this path.
 */
int main()
{
	freopen("park.in", "r", stdin);
	freopen("park.out", "w", stdout);

	now = clock();
	/* input() answers small cases by itself and exits */
	input();
	init();
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
	for (; Q--; )
	{
		int t, x, y; read(t), read(x), read(y);
		/* give up on a type-0 operation; otherwise make x the endpoint that
		 * dfs() visited first */
		if (!t) exit(0); if (dfn[y] < dfn[x]) swap(x, y);
		/* f: start vertex of the table covering x's group; g: LCA of x, y.
		 * The locals f and Ans shadow the globals of the same name. */
		int f = ca[ufs[x]], g = LCA(x, y);
		int64 Ans = ans[f][y];
		/* the table's start vertex lies above the LCA: take back the terms
		 * of the vertices from pa[g] up to f */
		if (dep[f] < dep[g])
		{
			for (t = pa[g]; t != pa[f]; t = pa[t])
				Ans -= v[c[t]] * w[sum[y][c[t]] - sum[t][c[t]] + 1];
			f = g;
		}
		/* add the terms of the vertices from x up to, but excluding, f */
		for (t = x; t != f; t = pa[t])
			Ans += v[c[t]] * w[sum[y][c[t]] - (sum[g][c[t]] << 1) + sum[t][c[t]] + (c[g] == c[t])];
		printf(fmt64"\n", Ans);
	}
	fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
}


你可能感兴趣的:(题目)