题解 [AT_dp_v]
题目大意
给出结点数为 $n$ 的树,求以每个结点为根,对包含根的联通块计数。
思路
显然是换根 dp。
首先考虑一个根节点的情况,设 $f_u$ 表示以 $u$ 为根节点的子树,强制选 $u$ 的方案数,于是有:
$$ f_u=\prod\limits_{v\in son_u}(f_v+1) $$考虑换根,设 $ans_u$ 表示以 $u$ 为整棵树的根,强制选 $u$ 的方案数。于是有:
$$ ans_u=\frac{ans_{fa_u}}{f_u+1}\times f_u $$然而要对 $m$ 取模,但 $m$ 不一定是质数,所以没法做除法。
考虑暴力相乘,设 $g_u$ 为除去 $u$ 的子树,以 $u$ 为根节点,强制选 $u$ 的方案数,得到:
$$ g_u=g_{fa_u}\prod_{v\in fa_u \operatorname{and}v\ne u}(f_v+1) $$于是问题变成如何求 $\prod_{v\in fa_u \operatorname{and}v\ne u}(f_v+1)$。
想不到前缀后缀积怎么办?
不难想到分治,设 $fa_u$ 的儿子区间 $[l,r]$ 表示第 $l$ 到第 $r$ 个儿子,设区间中点 $mid$。
先将区间 $[mid+1,r]$ 的 $f_v+1$ 乘起来,记到 $mul$ 中,然后递归区间 $[l,mid]$,再把 $mul$ 还原,将区间 $[l,mid]$ 的 $f_v+1$ 乘起来,递归区间 $[mid+1,r]$。
当 $l=r$ 时 $mul$ 即为 $\prod_{v\in fa_u \operatorname{and}v\ne u}(f_v+1)$,于是就得到了 $g_u$。
最终答案即为 $f_u\times g_u$。
空间虽然小了点,但时间复杂度为 $O(n\log n)$。
Code:
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lp(i, j, n) for(int i = j; i <= n; ++i)
#define dlp(i, n, j) for(int i = n; i >= j; --i)
#define mst(n, v) memset(n, v, sizeof(n))
#define mcy(n, v) memcpy(n, v, sizeof(v))
#define INF 1e18
#define MAX4 0x3f3f3f3f
#define MAX8 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define pll pair<ll, ll>
#define co(x) cerr << (x) << ' '
#define cod(x) cerr << (x) << endl
#define fi first
#define se second
#define eps 1e-8
#define lc(x) ((x) << 1)
#define rc(x) ((x) << 1 ^ 1)
#define pb(x) emplace_back(x)
using namespace std;
const int N = 100010;
int n, m;
int f[N], g[N];
struct edge { int v, nxt; } E[N << 1];
int en, hd[N];
void add(int u, int v) { E[++en] = { v, hd[u] }, hd[u] = en; }
void dfs1(int u, int fa) {
f[u] = 1;
for(int i = hd[u]; i; i = E[i].nxt) {
int v = E[i].v;
if(v == fa) continue;
dfs1(v, u), f[u] = 1ll * f[u] * (f[v] + 1) % m;
}
}
int son[N], tot;
ll mul;
void solve(int l, int r) {
if(l == r) return g[son[l]] = (mul + 1) % m, void();
ll t = mul; int mid = l + r >> 1;
lp(i, mid + 1, r) mul = mul * (f[son[i]] + 1) % m;
solve(l, mid);
mul = t; lp(i, l, mid) mul = mul * (f[son[i]] + 1) % m;
solve(mid + 1, r);
}
void dfs2(int u, int fa) {
tot = 0, mul = g[u];
for(int i = hd[u]; i; i = E[i].nxt) {
if(E[i].v != fa) son[++tot] = E[i].v;
}
if(tot >= 1) solve(1, tot);
for(int i = hd[u]; i; i = E[i].nxt) {
if(E[i].v != fa) dfs2(E[i].v, u);
}
}
signed main() {
// freopen(".in", "r", stdin);
// freopen(".out", "w", stdout);
#ifndef READ
ios::sync_with_stdio(false);
cin.tie(0);
#endif
cin >> n >> m;
int u, v;
lp(i, 1, n - 1) cin >> u >> v, add(u, v), add(v, u);
dfs1(1, 0), g[1] = 1, dfs2(1, 0);
lp(i, 1, n) cout << 1ll * f[i] * g[i] % m << endl;
return 0;
}
暂无评论,快来抢沙发!