给定一棵 $n$ 个节点的树,每个节点上有颜色,要求支持两种操作:

  1. 询问两点 $(u, v)$ 之间的路径上有多少段颜色(相邻且颜色相同的节点算作同一段)
  2. 将两点 $(u, v)$ 之间路径上所有节点的颜色染成 $c$

题解

使用树链剖分配合线段树维护。注意线段树的每个节点除了颜色段数之外,还要维护该区间最左端和最右端位置上节点的颜色;合并两段区间的答案时,如果左区间最右端节点与右区间最左端节点的颜色相同,答案要减 $1$。

树剖统计答案的时候也要注意这一点:每向上跳过一条重链时,如果链顶节点与它的父节点颜色相同,答案同样要减 $1$。

代码

#include <cstdio>
#include <cstring>
#include <climits>
#include <stack>
#include <algorithm>
// Forward declarations: Node and Edge refer to each other through pointers.
struct Edge;
struct Node;

// Capacity of the fixed-size tables below; node indices are 1-based.
const int MAX_N = 100000;

// a[i] is the initial color of the node whose DFS index (Node::dfn) is i.
// It is filled in main() and read by Segt::build() for the leaves.
int a[MAX_N + 1];
/**
 * A tree vertex together with its heavy-light decomposition bookkeeping.
 * The global table N[] has static storage, so every field starts zeroed.
 */
struct Node {
    Edge *e;     // head of this vertex's singly linked adjacency list
    Node *fa;    // parent in the rooted tree (stays null for the root)
    Node *top;   // topmost vertex of the chain this vertex belongs to
    Node *maxc;  // child with the largest subtree size, or null for a leaf
    int col;     // color as read from the input
    int dep;     // depth; split() gives the root depth 1
    int size;    // number of vertices in this vertex's subtree
    int dfn;     // 1-based visiting index assigned by dfs2()
    bool vis;    // set once dfs1() has entered this vertex
} N[MAX_N + 1];
/**
 * One directed entry in a vertex's adjacency list.
 */
struct Edge {
    Node *to;  // vertex this entry points at
    Edge *ne;  // next entry in the same list, or null at the end

    // Builds an entry for the list of `fr`, chained in front of fr's current
    // list head. The caller is responsible for storing the new entry in fr->e.
    Edge(Node *fr, Node *to) {
        this->to = to;
        this->ne = fr->e;
    }
};
inline void addEdge(int fr, int to) {
N[fr].e = new Edge(&N[fr], &N[to]);
N[to].e = new Edge(&N[to], &N[fr]);
}
/**
 * First pass of the decomposition: starting from `v`, records parent, depth
 * and subtree size for every vertex reached, and picks for each vertex the
 * child with the largest subtree (maxc). The caller sets v->dep beforehand.
 */
void dfs1(Node *v) {
    v->vis = true;
    v->size = 1;
    for (Edge *edge = v->e; edge; edge = edge->ne) {
        Node *child = edge->to;
        // Already-entered neighbours (i.e. the parent) are skipped.
        if (child->vis) continue;

        child->fa = v;
        child->dep = v->dep + 1;
        dfs1(child);

        v->size += child->size;
        // Strict '>' keeps the first-seen child on ties.
        if (!v->maxc || child->size > v->maxc->size) {
            v->maxc = child;
        }
    }
}
/**
 * Second pass of the decomposition: hands out consecutive visiting indices
 * (dfn) and chain tops. The maxc child is visited first, so the vertices of
 * one chain receive contiguous dfn values.
 */
void dfs2(Node *v) {
    static int ts = 0;  // last index handed out; persists across calls
    ts += 1;
    v->dfn = ts;

    // A vertex continues its parent's chain only if it is the parent's maxc
    // child; otherwise it starts a new chain of its own.
    const bool continuesParentChain = v->fa && v->fa->maxc == v;
    v->top = continuesParentChain ? v->fa->top : v;

    if (v->maxc) dfs2(v->maxc);

    for (Edge *edge = v->e; edge; edge = edge->ne) {
        Node *child = edge->to;
        // Skip the parent entry and the maxc child that was handled above.
        if (child->fa != v || child == v->maxc) continue;
        dfs2(child);
    }
}
/**
 * Runs the full decomposition from `root`: the root gets depth 1, then
 * dfs1() computes parents/depths/sizes/maxc and dfs2() assigns dfn and
 * chain tops.
 */
void split(Node *root) {
    root->dep = 1;  // depths are 1-based

    dfs1(root);     // pass 1: fa, dep, size, maxc
    dfs2(root);     // pass 2: dfn, top
}
/**
 * Segment tree over positions 1..n (DFS indices). Each node stores, for its
 * interval [l, r], the number of maximal same-color runs (sum) and the colors
 * at its two ends (lcol, rcol). Range assignment is lazy: tag holds a pending
 * color for the children, with -1 meaning "nothing pending".
 */
struct Segt {
    int l, r;
    Segt *lc, *rc;
    int sum, tag;
    int lcol, rcol;

    Segt(int l, int r, Segt *lc, Segt *rc, int sum = 0, int lcol = -1, int rcol = -1)
        : l(l), r(r), lc(lc), rc(rc), sum(sum), tag(-1), lcol(lcol), rcol(rcol) {}

    // Recomputes this node from its two children. Runs that touch at the
    // boundary with the same color are one run, hence the decrement.
    void maintain() {
        sum = lc->sum + rc->sum;
        if (lc->rcol == rc->lcol) --sum;
        lcol = lc->lcol;
        rcol = rc->rcol;
    }

    // Paints the whole interval with `col` and remembers it for the children.
    void change(int col) {
        tag = col;
        lcol = col;
        rcol = col;
        sum = 1;
    }

    // Forwards a pending assignment, if any, to both children.
    void pushDown() {
        if (tag == -1) return;
        lc->change(tag);
        rc->change(tag);
        tag = -1;
    }

    // Assigns color `col` to every position in [l, r].
    void modify(int l, int r, int col) {
        // No overlap with this node's interval.
        if (r < this->l || this->r < l) return;
        // This node's interval is completely covered.
        if (l <= this->l && this->r <= r) {
            change(col);
            return;
        }
        pushDown();
        lc->modify(l, r, col);
        rc->modify(l, r, col);
        maintain();
    }

    // Returns the number of same-color runs inside [l, r].
    int query(int l, int r) {
        if (l <= this->l && this->r <= r) return sum;

        pushDown();
        const int mid = this->l + (this->r - this->l) / 2;
        if (l > mid) return rc->query(l, r);
        if (r <= mid) return lc->query(l, r);

        // The range straddles mid, so positions mid and mid + 1 are both
        // queried; they are exactly lc's right end and rc's left end.
        int total = lc->query(l, r);
        total += rc->query(l, r);
        if (lc->rcol == rc->lcol) total -= 1;
        return total;
    }

    // Returns the color stored at position `pos`.
    int query(int pos) {
        Segt *cur = this;
        while (cur->l != cur->r) {
            cur->pushDown();
            const int mid = cur->l + (cur->r - cur->l) / 2;
            cur = (pos <= mid) ? cur->lc : cur->rc;
        }
        return cur->lcol;
    }

    // Builds the tree for [l, r]; leaf colors come from the global a[].
    static Segt *build(int l, int r) {
        if (l == r) return new Segt(l, r, NULL, NULL, 1, a[l], a[r]);

        const int mid = l + (r - l) / 2;
        Segt *left = build(l, mid);
        Segt *right = build(mid + 1, r);
        Segt *node = new Segt(l, r, left, right);
        node->maintain();
        return node;
    }
} *segt;
inline void change(int a, int b, int c) {
Node *u = &N[a], *v = &N[b];
while (u->top != v->top) {
if (u->top->dep < v->top->dep) std::swap(u, v);
segt->modify(u->top->dfn, u->dfn, c);
u = u->top->fa;
}
if (u->dep > v->dep) std::swap(u, v);
segt->modify(u->dfn, v->dfn, c);
}
inline int query(int a, int b) {
Node *u = &N[a], *v = &N[b];
int res = 0;
while (u->top != v->top) {
if (u->top->dep < v->top->dep) std::swap(u, v);
res += segt->query(u->top->dfn, u->dfn);
if (segt->query(u->top->dfn) == segt->query(u->top->fa->dfn)) res--;
u = u->top->fa;
}
if (u->dep > v->dep) std::swap(u, v);
res += segt->query(u->dfn, v->dfn);
return res;
}
int main() {
// freopen("data.in", "r", stdin);
int n, m;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &N[i].col);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
addEdge(u, v);
}
split(&N[1]);
for (int i = 1; i <= n; i++) {
a[N[i].dfn] = N[i].col;
}
segt = Segt::build(1, n);
for (int i = 1; i <= m; i++) {
char cmd[2];
scanf("%s", cmd);
if (cmd[0] == 'C') {
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
change(a, b, c);
} else if (cmd[0] == 'Q') {
int a, b;
scanf("%d %d", &a, &b);
printf("%d\n", query(a, b));
}
}
return 0;
}