小 C 同学认为跑步非常有趣,于是决定制作一款叫做《天天爱跑步》的游戏。《天天爱跑步》是一个养成类游戏,需要玩家每天按时上线,完成打卡任务。

这个游戏的地图可以看作一棵包含 n n 个结点和 n1 n - 1 条边的树,每条边连接两个结点,且任意两个结点存在一条路径互相可达。树上结点编号为从 1 1 n n 的连续正整数。

现在有 m m 个玩家,第 i i 个玩家的起点为 Si S_i ,终点为 Ti T_i 。每天打卡任务开始时,所有玩家在第 0 0 秒同时从自己的起点出发,以每秒跑一条边的速度,不间断地沿着最短路径向着自己的终点跑去,跑到终点后该玩家就算完成了打卡任务。(由于地图是一棵树,所以每个人的路径是唯一的)

小 C 想知道游戏的活跃度,所以在每个结点上都放置了一个观察员。在结点 j j 的观察员会选择在第 Wj W_j 秒观察玩家,一个玩家能被这个观察员观察到当且仅当该玩家在第 Wj W_j 秒也理到达了结点 j j 。小 C 想知道每个观察员会观察到多少人?

注意:我们认为一个玩家到达自己的终点后该玩家就会结束游戏,他不能等待一段时间后再被观察员观察到。即对于把结点 j j 作为终点的玩家:若他在第 Wj W_j 秒前到达终点,则在结点 j j 的观察员不能观察到该玩家;若他正好在第 Wj W_j 秒到达终点,则在结点 j j 的观察员可以观察到这个玩家。

题目链接

LYOI#100

题解

首先考虑在序列上的做法。

假设现在我们的序列是一个 [0,L) [0, L) 的序列,那么对于每一个玩家 (Si,Ti) (S_i, T_i) ,有两种情况。

Si<Ti S_i < T_i 时,只有在序列的 [Si,Ti] [S_i, T_i] 这段区间内的某个 j j 才有可能观察到这个玩家。 考虑在位置 j j 的观察员观察到玩家 i i 的充要条件,就是玩家 i i 到达位置 j j 的时间等于位置 j j 的观察员出现的时间,即 Wj W_j ,又因为相邻两个点的距离和玩家的速度都为 1 1 ,所以上述条件可以用式子表示出来,即 jSi=Wj j - S_i = W_j ,移项后得 jWj=Si j - W_j = S_i

对于一个确定的点,它的 jWj j - W_j 是确定的,也就是说这种情况的问题转化为,设 Xj=jWj X_j = j - W_j 区间加一个数,然后对于区间中的某个节点 j j 询问该位置上的数等于 Xj X_j 的个数。这个问题我们可以用差分化思想解决,即对于需要添加的一个数,在位置 S S 把它加上,在位置 T+1 T + 1 把它删去。

Si>Tj S_i > T_j 时,仍可以用差分化思想解决,只不过这个时候能够观察到的先决条件变为 Sij=Wj S_i - j = W_j 。我们设 Xj=j+Wj X_j = j + W_j ,按照上述做法来做即可。

问题转移到树上,考虑树链剖分。

把每一条链看成是一个序列,链上深度最小的点的位置为 0 0 ,向下位置递增,向上位置为负。

这样可以把每个玩家的运动过程拆成从 S S lca(S,T) lca(S, T) 的第一部分以及从 lca(S,T) lca(S, T) T T 的第二部分。

对于第一部分的每一条链上的情况,把它看作是 S>T S > T 的情况,第二部分的每一条链上的情况,把它看作是 S<T S < T 的情况。

我们记链上每个节点的位置为 id id ,其深度为 dep dep ,两点之间的距离为 dist dist

具体来讲,对于第一种情况,考虑链上的深度最大的由 S S lca lca 的必经点,记为 u u ,由它向上走到链上任意深度大于 lca(S,T) lca(S, T) 的点上的观察员都是可能观察到这位玩家的。由它向上走到的某一个合法的,链上坐标为 j j 的点能观察到这个玩家的充要条件是,由 S S u u 的距离加上从 u u j j 的距离等于 Wj W_j ,即 dist(S,u)+dist(u,j) dist(S, u) + dist(u, j) 。用已知的量表示出来就是 S.depu.dep+u.idj.id=Wj S.dep - u.dep + u.id - j.id = W_j

对于第二种情况,在某条链上,将 S S 的坐标看成一个负值,则它的坐标应该是从 S S 到这条链顶端节点的距离的相反数,我们将原有的条件取反,则得到 S=Wjj S = W_j - j

本来想画个图的,但是感觉让我画图还不如让我干说(其实是写题解的时候找不到画图工具了。。),不过有时间会补上图的。具体实现细节见代码吧。还有我的代码在洛谷上是过不去的。。本身 nlogn n\log n 的复杂度就有点不靠谱。。不过如果您会 ____ 的 ____ 的话树剖当然稳稳过啦。

其实这个题有 O(n+m) O(n + m) 的做法,只不过我目前还没有去看 qwq (其实就是懒找什么借口),如果写出线性做法的话再更新辣。

感觉自己是真的菜,NOIP2016 爆炸以后这么久才填完坑。。

代码

#include <cstdio>
#include <cstring>
#include <climits>
#include <new>
#include <queue>
#include <stack>
#include <vector>
#include <algorithm>
const int MAX_N = 300000;
int n, m;
struct Tag {
int x;
bool del;
Tag(bool del, int x) : x(x), del(del) {}
};
struct Node;
struct Edge;
struct Chain;
struct Node {
Edge *e;
Node *fa, *maxc;
Chain *chain;
int size, dep, dfn, id, w, x, ans;
bool vis;
std::vector<Tag> forwTag, bacwTag;
} N[MAX_N];
struct Edge {
Node *fr, *to;
Edge *ne;
Edge() {}
Edge(Node *fr, Node *to) : fr(fr), to(to), ne(fr->e) {}
} _pool[MAX_N * 2], *_end;
inline void addEdge(int fr, int to) {
N[fr].e = new (_end++) Edge(&N[fr], &N[to]);
N[to].e = new (_end++) Edge(&N[to], &N[fr]);
}
struct Chain {
Node *top, *bot;
std::vector<Node *> nodes;
int len;
} chains[MAX_N];
int chainCnt = 0;
template<typename T>
struct Stack {
T s[MAX_N];
int tot;
Stack() : tot(0) {}
void push(T val) {
s[tot++] = val;
}
void pop() {
tot--;
}
bool empty() {
return tot == 0;
}
T top() {
return s[tot - 1];
}
};
Stack<Node *> s;
inline void split() {
// std::stack<Node *> s;
N[1].dep = 1;
s.push(&N[1]);
while (!s.empty()) {
Node *v = s.top();
if (!v->vis) {
v->vis = true;
v->size = 1;
for (Edge *e = v->e; e; e = e->ne) {
if (e->to != v->fa) {
e->to->fa = v;
e->to->dep = v->dep + 1;
s.push(e->to);
}
}
} else {
for (Edge *e = v->e; e; e = e->ne) {
if (e->to->fa == v) {
v->size += e->to->size;
if (!v->maxc || e->to->size > v->maxc->size) v->maxc = e->to;
}
}
s.pop();
}
}
for (int i = 1; i <= n; i++) N[i].vis = false;
s.push(&N[1]);
N[1].dep = 1;
while (!s.empty()) {
Node *v = s.top();
if (!v->vis) {
v->vis = true;
if (!v->fa || v != v->fa->maxc) {
v->chain = &chains[chainCnt++];
v->chain->top = v;
v->id = 0;
} else {
v->chain = v->fa->chain;
v->id = v->fa->id + 1;
}
v->chain->nodes.push_back(v);
v->chain->bot = v;
for (Edge *e = v->e; e; e = e->ne) {
if (e->to->fa == v) s.push(e->to);
}
} else {
s.pop();
}
}
for (int i = 0; i < chainCnt; i++) chains[i].len = chains[i].nodes.size();
}
inline Node *lca(Node *u, Node *v) {
while (u->chain != v->chain) {
if (u->chain->top->dep < v->chain->top->dep) std::swap(u, v);
u = u->chain->top->fa;
}
if (u->dep > v->dep) return v;
else return u;
}
inline int dist(Node *u, Node *v, Node *p) {
return u->dep + v->dep - p->dep * 2;
}
inline void addTag(bool forw, Chain *chain, int s, int t, int x) {
if (forw) {
if (s > t) return;
chain->nodes[s]->forwTag.push_back(Tag(false, x));
chain->nodes[t]->forwTag.push_back(Tag(true, x));
} else {
if (s < t) return;
chain->nodes[s]->bacwTag.push_back(Tag(false, x));
chain->nodes[t]->bacwTag.push_back(Tag(true, x));
}
}
inline void play(Node *s, Node *t) {
if (s == t) {
if (s->w == 0) s->ans++;
return;
}
Node *p = lca(s, t), *u = s, *v = t;
if (dist(s, p, p) == p->w) p->ans++;
if (s != p) {
while (u->chain != p->chain) {
addTag(false, u->chain, u->id, 0, s->dep - u->dep + u->id);
u = u->chain->top->fa;
}
addTag(false, u->chain, u->id, p->id + 1, s->dep - u->dep + u->id);
}
if (t != p) {
while (v->chain != p->chain) {
addTag(true, v->chain, 0, v->id, (s->dep - p->dep) + (v->chain->top->dep - p->dep));
v = v->chain->top->fa;
}
addTag(true, v->chain, p->id + 1, v->id, (s->dep - p->dep) + (v->chain->top->dep - p->dep));
}
}
int _cnt[MAX_N * 4 + 1], *cnt = _cnt + MAX_N * 2;
inline void solve() {
for (int i = 0; i < chainCnt; i++) {
Chain &chain = chains[i];
// forw on tree
for (int j = 0; j < chain.len; j++) {
for (std::vector<Tag>::const_iterator it = chain.nodes[j]->forwTag.begin(); it != chain.nodes[j]->forwTag.end(); it++) {
if (!it->del) cnt[it->x]++;
}
chain.nodes[j]->ans += cnt[chain.nodes[j]->w - j];
for (std::vector<Tag>::const_iterator it = chain.nodes[j]->forwTag.begin(); it != chain.nodes[j]->forwTag.end(); it++) {
if (it->del) cnt[it->x]--;
}
}
// bacw on tree
for (int j = chain.len - 1; j >= 0; j--) {
for (std::vector<Tag>::const_iterator it = chain.nodes[j]->bacwTag.begin(); it != chain.nodes[j]->bacwTag.end(); it++) {
if (!it->del) cnt[it->x]++;
}
chain.nodes[j]->ans += cnt[chain.nodes[j]->w + j];
for (std::vector<Tag>::const_iterator it = chain.nodes[j]->bacwTag.begin(); it != chain.nodes[j]->bacwTag.end(); it++) {
if (it->del) cnt[it->x]--;
}
}
}
}
int main() {
freopen("running.in", "r", stdin);
freopen("running.out", "w", stdout);
_end = _pool;
scanf("%d %d", &n, &m);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
addEdge(u, v);
}
split();
for (int i = 1; i <= n; i++) scanf("%d", &N[i].w);
for (int i = 1; i <= m; i++) {
int u, v;
scanf("%d %d", &u, &v);
play(&N[u], &N[v]);
}
solve();
for (int i = 1; i <= n; i++) printf("%d%c", N[i].ans, i == n ? '\n' : ' ');
return 0;
}