定义区间树为线段树的拓展,即每次断开的位置可以不是线段的中心。
给定一个 的区间树和 次询问,每次询问包含一个正整数 , 你需要求出有多少区间的时间复杂度恰好等于 。
。
题解
在线回答询问无意义,考虑利用生成函数处理出所有询问的答案。
询问 选中的线段( 的线段,而非经过的线段),LCA 往两侧深度单调减(且中间平的一段的长度至多为 )。
求出往两侧单调的生成函数合并,类似:
其中 分别表示左/右端点和当前线段的左/右端点相同的线段(不包括完全相同的情况)的生成函数, 是当前节点, 是左儿子, 是右儿子。
处理一些平凡情况,在断点计算贡献:
就能做到 复杂度。
进一步优化复杂度,考虑边分治:假设当前处理子树 ,边分的子树 。递归处理出 的路径上的 。下面考虑 对路径上点的 和 的 的贡献。
前者可以分别考虑 的贡献,通过两次卷积得到。后者则是路径上的 通过一定位移得到。
至于处理可以分别在两侧继续边分。复杂度 。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, mod = 998244353;
int _, n, m, cnt, siz[N], vis[N], mid[N], ch[N][2], fa[N], l[N], r[N], dep[N], rev[N << 2];
struct z {
int x;
z(int x = 0) : x(x) {}
friend inline z operator*(z a, z b) { return (long long)a.x * b.x % mod; }
friend inline z operator-(z a, z b) { return (a.x -= b.x) < 0 ? a.x + mod : a.x; }
friend inline z operator+(z a, z b) { return (a.x += b.x) >= mod ? a.x - mod : a.x; }
} w[N << 2];
vector<z> ans;
inline z fpow(z a, int b) {
z s = 1;
for (; b; b >>= 1, a = a * a)
if (b & 1) s = s * a;
return s;
}
void dft(vector<z> &a, int lim) {
a.resize(lim);
for (int i = 0; i < lim; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int len = 1; len < lim; len <<= 1)
for (int i = 0; i < lim; i += (len << 1))
for (int j = 0; j < len; j++) {
z x = a[i + j], y = a[i + j + len] * w[j + len];
a[i + j] = x + y, a[i + j + len] = x - y;
}
}
vector<z> operator+(vector<z> a, const vector<z> &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] + b[i];
return a;
}
vector<z> operator-(vector<z> a, const vector<z> &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] - b[i];
return a;
}
vector<z> operator*(vector<z> a, vector<z> b) {
int len = a.size() + b.size() - 1, lim = 1, k = 0;
while (lim < len) lim <<= 1, ++k;
for (int i = 0; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (k - 1));
dft(a, lim), dft(b, lim);
for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
dft(a, lim), reverse(&a[1], &a[lim]);
z inv = fpow(lim, mod - 2);
for (int i = 0; i < lim; i++) a[i] = a[i] * inv;
return a.resize(len), a;
}
void shift(z x, vector<z> &dst, size_t dta) {
dst.resize(max(dst.size(), dta + 1));
dst[dta] = dst[dta] + x;
}
void shift(const vector<z> &src, vector<z> &dst, size_t dta) {
dst.resize(max(dst.size(), src.size() + dta));
for (int i = 0; i < src.size(); i++) dst[i + dta] = dst[i + dta] + src[i];
}
void dfsInit(int &u, int l, int r, int dep) {
if (l == r) {
u = n - 1 + l;
} else {
u = ++cnt;
dfsInit(ch[u][0], l, mid[u], dep + 1), fa[ch[u][0]] = u;
dfsInit(ch[u][1], mid[u] + 1, r, dep + 1), fa[ch[u][1]] = u;
}
::l[u] = l, ::r[u] = r, ::dep[u] = dep;
}
int calcSize(int u) {
if (!u || vis[u]) return 0;
return siz[u] = 1 + calcSize(ch[u][0]) + calcSize(ch[u][1]);
}
pair<int, int> findSubTree(int u, int lim) {
if (!u || vis[u]) return {-1, -1};
if (siz[u] < lim) return {u, siz[u]};
pair<int, int> x = ch[u][0] ? findSubTree(ch[u][0], lim) : make_pair(-1, -1);
pair<int, int> y = ch[u][1] ? findSubTree(ch[u][1], lim) : make_pair(-1, -1);
return x.second > y.second ? x : y;
}
void calc(bool fl, int u, int mov, vector<z> &f) {
if (vis[u] || l[u] == r[u]) return shift(1, f, mov + 1);
shift(1, f, mov + 1);
shift(mod - 1, f, mov + 3);
calc(fl, ch[u][0], mov + (fl ? 2 : 1), f);
calc(fl, ch[u][1], mov + (fl ? 1 : 2), f);
}
pair<vector<z>, vector<z>> fuck(int u) {
if (vis[u] || l[u] == r[u]) return {{0, 1}, {0, 1}};
vector<z> Ll, Lr, Rl, Rr, Lu, Ru;
Lu = Ru = vector<z>{0, 1, 0, mod - 1};
tie(Ll, Rl) = fuck(ch[u][0]);
tie(Lr, Rr) = fuck(ch[u][1]);
shift(Ll, Lu, 1), shift(Lr, Lu, 2);
shift(Rr, Ru, 1), shift(Rl, Ru, 2);
shift(Rl * Lr, ans, dep[u]);
return {Lu, Ru};
}
pair<vector<z>, vector<z>> solve(int u) {
if (vis[u] || l[u] == r[u]) return {{0, 1}, {0, 1}};
int siz = calcSize(u);
int v = findSubTree(u, (siz * 2) / 4).first;
if (v == -1) return fuck(u);
vector<z> Lu, Ru, Lv, Rv, Lt, Rt, T;
tie(Lv, Rv) = solve(v);
vis[v] = 1;
tie(Lu, Ru) = solve(u);
vis[v] = 0;
int lmov = 0, rmov = 0;
for (int p = v; p != u; lmov += ch[fa[p]][0] == p ? 1 : 2, rmov += ch[fa[p]][0] == p ? 2 : 1, p = fa[p]) {
int f = fa[p], q = ch[f][0] == p ? ch[f][1] : ch[f][0];
if (ch[f][0] == p) {
T.clear(), calc(0, q, 0, T), shift(T, Lt, rmov + dep[f] - dep[u]);
} else {
T.clear(), calc(1, q, 0, T), shift(T, Rt, lmov + dep[f] - dep[u]);
}
}
shift(mod - 1, Lv, 1);
shift(mod - 1, Rv, 1);
shift(Lv * Rt, ans, dep[u]);
shift(Rv * Lt, ans, dep[u]);
shift(Lv, Lu, lmov);
shift(Rv, Ru, rmov);
return {Lu, Ru};
}
void solution() {
for (int i = 1; i < n; i++) cin >> mid[i];
dfsInit(_, 1, n, 1);
solve(1);
for (int i = 1; i < (n << 1); i++) shift(1, ans, dep[i]);
for (int i = 1; i < n; i++) shift(mod - 1, ans, dep[i] + 2);
for (int i = 1, q; i <= m; i++) {
cin >> q;
cout << (q < ans.size() ? ans[q].x : 0) << '\n';
}
}
void recycle() {
cnt = 0;
ans.clear();
memset(ch, 0, sizeof(ch));
memset(fa, 0, sizeof(fa));
}
int main() {
cin.tie(0)->sync_with_stdio(0);
for (int len = 1; len < (N << 1); len <<= 1) {
z wn = fpow(3, (mod - 1) / (len << 1));
w[len] = 1;
for (int i = 1; i < len; i++) w[i + len] = w[i + len - 1] * wn;
}
while (cin >> n >> m) solution(), recycle();
}