一道非常有意思的(动态)DP 题。

讲道理的我考场上是会 70 分的但是由于 T1 傻逼了太久没调出来(报警了)

假设不修改的情况下答案为 W ,叶子节点的个数为 m,我们把总共的 2^m - 1 种选法分两种情况讨论:

  • W 被选中,这样的情况一共有 2^{m-1} 种,可见只要花费 1 的代价一定可以使根节点的值发生改变,下面不再讨论;
  • W 未被选中,这样的情况一共有 2^{m - 1} - 1 种,我们可以考虑 DP 解决。枚举当前的 C,求出所有花费代价 \leq C 的方案数,显然如果我们可以求得对于 C \in [L - 1, R],差分后就可得到答案数组。

进一步观察我们可以发现对于标号 id < W 的叶子节点,修改为 id + C 一定最优,对于标号 id > W 的叶子节点,修改为 id - C 一定最优。

假设 f_i, g_i 来表示把 i 点修改到 大于 / 小于 W 的概率,显然对于所有叶子节点:

f_u, g_u = \left\{ \begin{aligned} &1, \left\{ \begin{aligned} &\frac 12, &(u - c < W) \\ &0 &(u - c \geq W) \end{aligned} \right. &(u > W) \\ &\left\{ \begin{aligned} &\frac 12, &(u + c > W) \\ &0 &(u + c \leq W) \end{aligned} \right. , 1 &(u < W)\\ &0, 0 &(u = W) \end{aligned} \right.

转移的话也比较直接:

f_u = \left\{ \begin{aligned} 1 - &\prod_{u \rightarrow v} 1 - f_v &\small\text{(奇数层)}\\ &\prod_{u \rightarrow v} f_v &\small\text{(偶数层)} \end{aligned} \right. g_u = \left\{ \begin{aligned} &\prod_{u \rightarrow v} g_v &\small\text{(奇数层)} \\ 1 - &\prod_{u \rightarrow v} 1 - g_v &\small\text{(偶数层)} \end{aligned} \right.

根节点的答案为 \left( 1 - (1-f_1) (1-g_1) \right) \cdot 2^{m-1}

这样可以得到 70 分的好成绩。


现在考虑正解,由于 C 的不断增大,所以每个叶子节点的 fg 值会由 0 改变为 \frac 12,显然这样的改变只有 m-1 次,可以通过适当的方式来处理,也就是动态 DP。

但同时还有个问题,就是奇数层和偶数层的转移不一样。观察可以发现:

\begin{aligned} f_u &= 1 - \prod_{u \rightarrow v} 1 - f_v \\ &= 1 - \prod_{u \rightarrow v} (1 - \prod_{v \rightarrow v'} f_{v'}) \end{aligned}

显然我们可以把在奇数层的两个「1-」分一个到偶数层。

如果我们定义 f'_u 使得

f'_u = \left\{\begin{aligned} &1 - f_u &(dep_u \bmod 2 = 0) \\ &f_u &(dep_u \bmod 2 = 1) \end{aligned}\right.

那么转移可以变为

f'_u = 1 - \prod_{u \rightarrow v} f'_v

g \Rightarrow g' 同理。


下面回到正题动态 DP 了,显然转化后:

\begin{aligned} f'_u = 1 - \prod_{u \rightarrow v} f'_v = 1 - f'_{son_u} \cdot \prod_{v \neq son_u} f'_v \\ g'_u = 1 - \prod_{u \rightarrow v} g'_v = 1 - g'_{son_u} \cdot \prod_{v \neq son_u} g'_v \end{aligned}

写成矩阵转移的形式

\left[ \begin{matrix} -\prod_{v \neq son_u} f'_v & 0 & 1 \\ 0 & -\prod_{v \neq son_u} g'_v & 1 \\ 0 & 0 & 1 \end{matrix} \right] \times \left[ \begin{matrix} f'_{son_u} \\ g'_{son_u} \\ 1 \end{matrix} \right] = \left[ \begin{matrix} f'_u \\ g'_u \\ 1 \end{matrix} \right]

用全局平衡二叉树维护即可。


每次写全局平衡二叉树总是在 debug 上浪费大量时间,看来有必要总结一下了。

经常写错的地方无非两个:

  • 一个是预处理,也就是未修改前的矩阵维护。我一般是事先算好 DP 值,然后在搞成矩阵这样比较暴力的维护
  • 另一个是修改操作,一般说需要注意的就是如果这个点是当前链的根,也就是它 father 的轻儿子,一定要把事先的影响先排除,再把新的换上。

此题中 \prod_{v \neq son_u} f'_v 可能为 0,也就是说全局平衡二叉树在修改的时候可能会有除以 0 的情况。按照和「切树游戏」类似的方法,可以给每个点的轻儿子开一棵线段树,也可以写一个二元组 (x, y) 表示 x \cdot 0^y

代码:

// =================================
//   author: memset0
//   date: 2019.05.02 13:23:23
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define FLAG_F 1926
#define FLAG_G  817
#define ll long long
#define debug(...) ((void)0)
#ifndef debug
#define debug(...) fprintf(stderr,__VA_ARGS__)
#endif
namespace ringo {
template <class T> inline void read(T &x) {
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}
template <class T> inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}
template <class T> inline void print(T x, char c) { print(x), putchar(c); }
template <class T> inline void print(T *a, int l, int r, std::string s = "") {
    if (s != "") std::cout << s << ": ";
    for (int i = l; i <= r; i++) print(a[i], " \n"[i == r]);
}

const int N = 2e5 + 10, inf = 1e9, mod = 998244353, half = (mod + 1) >> 1;
bool is_leaf[N];
int fa[N], dep[N];
int n, w, c, m, L, R, LL, RR;
std::vector<int> todo1[N], todo2[N];
int tot = 2, hed[N], to[N << 1], nxt[N << 1];
int f[N], g[N], ans[N], sum[N], son[N], siz[N];

inline int dec(int a, int b) { a -= b; return a < 0 ? a + mod : a; }
inline int inc(int a, int b) { a += b; return a >= mod ? a - mod : a; }
inline int mul(int a, int b) { return (ll)a * b - (ll)a * b / mod * mod; }
inline int inv(int x) { return x < 2 ? 1 : mul(mod - mod / x, inv(mod % x)); }
inline int fpow(int a, int b) { int s = 1; for (; b; b >>= 1, a = mul(a, a)) if (b & 1) s = mul(s, a); return s; }

namespace bst {
    int rt, top, fa[N], stk[N], ch[N][2], bottom[N];

    struct multinum {
        int v, z;
        inline int get() { return z ? 0 : v; }
        inline void operator*=(int p) { p ? v = mul(v, p) : ++z; }
        inline void operator/=(int p) { p ? v = mul(v, inv(p)) : --z; }
        inline void init(int x) { x ? (v = x, z = 0) : (v = 0, z = 1); }
    } lf[N], lg[N];

    struct matrix {
        int a[3][3];
        inline void out() {
            printf("[%d %d %d] [%d %d %d] [%d %d %d]\n",
                a[0][0], a[0][1], a[0][2],
                a[1][0], a[1][1], a[1][2],
                a[2][0], a[2][1], a[2][2]);
        }
    } f[N], g[N], h[N];
    inline matrix operator*(const matrix &a, const matrix &b) {
        matrix c; memset(c.a, 0, sizeof(c.a));
        for (register int i = 0; i < 3; i++)
            for (register int j = 0; j < 3; j++)
                for (register int k = 0; k < 3; k++)
                    c.a[i][j] = (c.a[i][j] + (ll)a.a[i][k] * b.a[k][j]) % mod;
        return c;
    }

    inline void get_value(int x, int &F, int &G) {
        matrix it = f[x] * h[bottom[x]];
        F = it.a[0][0], G = it.a[1][0];
    }

    inline void get_value_with_backup(int x, int &F, int &G, const matrix &backup) {
        matrix it = f[x] * backup;
        F = it.a[0][0], G = it.a[1][0];
    }

    inline void maintain(int u) {
        f[u] = g[u];
        if (ch[u][0]) f[u] = f[ch[u][0]] * f[u];
        if (ch[u][1]) f[u] = f[u] * f[ch[u][1]];
    }   

    int build(int l, int r) {
        // printf("build %d %d => ", l, r), print(stk, l, r);
        int sum = 0, tmp = 0;
        for (int i = l; i <= r; i++) sum += siz[stk[i]] - siz[son[stk[i]]];
        for (int i = l; i <= r; i++) {
            tmp += siz[stk[i]] - siz[son[stk[i]]];
            if ((tmp << 1) >= sum) {
                fa[ch[stk[i]][0] = build(l, i - 1)] = stk[i];
                fa[ch[stk[i]][1] = build(i + 1, r)] = stk[i];
                return maintain(stk[i]), stk[i];
            }
        }
        return 0;
    }

    inline int build(int u) {
        top = 0;
        for (int x = u; x; x = son[x]) stk[++top] = x;
        h[stk[top]].a[0][0] = ringo::f[stk[top]];
        h[stk[top]].a[1][0] = ringo::g[stk[top]];
        h[stk[top]].a[2][0] = 1;
        for (int i = 1; i <= top; i++) {
            bottom[stk[i]] = stk[top];
            if (i == top) {
                g[stk[i]].a[0][0] = 1;
                g[stk[i]].a[1][1] = 1;
                g[stk[i]].a[2][2] = 1;
            } else {
                lf[stk[i]].init(mod - 1);
                lg[stk[i]].init(mod - 1);
                for (int j = hed[stk[i]], v; v = to[j], j; j = nxt[j])
                    if (v != ringo::fa[stk[i]] && v != ringo::son[stk[i]]) {
                        // printf("%d <- %d : %d %d\n", u, v, ringo::f[v], ringo::g[v]);
                        lf[stk[i]] *= ringo::f[v];
                        lg[stk[i]] *= ringo::g[v];
                    }
                // printf(">> %d : [%d %d] [%d %d]\n", stk[i], lf[stk[i]].v, lf[stk[i]].z, lg[stk[i]].v, lg[stk[i]].z);
                g[stk[i]].a[0][0] = lf[stk[i]].get();
                g[stk[i]].a[1][1] = lg[stk[i]].get();
                g[stk[i]].a[0][2] = 1;
                g[stk[i]].a[1][2] = 1;
                g[stk[i]].a[2][2] = 1;
            }
        }
        // printf("==> ");
        // for (int i = 1; i <= top; i++) print(stk[i], " \n"[i == top]);
        return build(1, top);
    }

    inline void modify(int s, int flag) {
        // printf("modify %d %s\n", s, flag == FLAG_F ? "F" : "G");
        matrix backup = h[s];
        for (int x = s, now_f, now_g; x; x = fa[x]) {
            // printf("  >> %d : %d\n", x, ch[fa[x]][0] != x && ch[fa[x]][1] != x && fa[x]);
            if (ch[fa[x]][0] != x && ch[fa[x]][1] != x && fa[x]) {
                if (bottom[x] == s) get_value_with_backup(x, now_f, now_g, backup);
                else get_value(x, now_f, now_g);
                // printf("%d <- %d : [-] %d %d\n", fa[x], x, now_f, now_g);
                lf[fa[x]] /= now_f;
                lg[fa[x]] /= now_g;
                // printf(">>>> [%d %d] [%d %d]\n", lf[fa[x]].v, lf[fa[x]].z, lg[fa[x]].v, lg[fa[x]].z);
            }
            if (x == s) {
                if (flag == FLAG_F) h[x].a[0][0] = half;
                if (flag == FLAG_G) h[x].a[1][0] = half;
            } else {
                g[x].a[0][0] = lf[x].get();
                g[x].a[1][1] = lg[x].get();
            }
            maintain(x);
            if (ch[fa[x]][0] != x && ch[fa[x]][1] != x && fa[x]) {
                get_value(x, now_f, now_g);
                // printf("%d <- %d : [+] %d %d\n", fa[x], x, now_f, now_g);
                lf[fa[x]] *= now_f;
                lg[fa[x]] *= now_g;
                // printf(">>>> [%d %d] [%d %d]\n", lf[fa[x]].v, lf[fa[x]].z, lg[fa[x]].v, lg[fa[x]].z);
            }
        }
    }

}

int init_dfs(int u) {
    siz[u] = 1;
    int res = dep[u] & 1 ? -inf : inf;
    for (int i = hed[u], v; v = to[i], i; i = nxt[i])
        if (v != fa[u]) {
            fa[v] = u, dep[v] = dep[u] + 1;
            res = dep[u] & 1 ? std::max(res, init_dfs(v)) : std::min(res, init_dfs(v));
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) son[u] = v;
        }
    if (res == inf || res == -inf) {
        m += is_leaf[u] = true;
        return u;
    }
    return res;
}

int build(int x) {
    for (int u = x; u; u = son[u])
        for (int i = hed[u], v; v = to[i], i; i = nxt[i])
            if (v != fa[u] && v != son[u])
                bst::fa[build(v)] = u;
    return bst::build(x);
}

inline void solve(int u) {
    if (is_leaf[u]) return;
    f[u] = g[u] = 1;
    for (int i = hed[u], v; v = to[i], i; i = nxt[i])
        if (v != fa[u]) {
            solve(v);
            f[u] = mul(f[u], f[v]);
            g[u] = mul(g[u], g[v]);
        }
    f[u] = dec(1, f[u]);
    g[u] = dec(1, g[u]);
}

void main() {
    read(n), read(L), read(R);
    LL = std::max(L - 1, 1), RR = std::max(R, n - 1);
    for (int u, v, i = 1; i < n; i++) {
        read(u), read(v);
        nxt[tot] = hed[u], to[tot] = v, hed[u] = tot++;
        nxt[tot] = hed[v], to[tot] = u, hed[v] = tot++;
    }

    dep[1] = 1, w = init_dfs(1);
    for (int c, u = 1; u <= n; u++) if (is_leaf[u]) {
        if (u > w) {
            c = u - w + 1;
            f[u] = 1, g[u] = 0;
            if (LL <= c && c <= RR) todo1[c].push_back(u);
            else if (c < LL) g[u] = half;
        } else if (u < w) {
            c = w - u + 1;
            g[u] = 1, f[u] = 0;
            if (LL <= c && c <= RR) todo2[c].push_back(u);
            else if (c < LL) f[u] = half;
        } else if (u == w) {
            f[u] = g[u] = 0;
        }
        (dep[u] & 1) ? g[u] = dec(1, g[u]) : f[u] = dec(1, f[u]);
    }
    solve(1);
    bst::rt = build(1);
    bst::get_value(bst::rt, f[1], g[1]);
    // for (int i = 1; i <= n; i++) {
    //  printf("F %d ", i), bst::f[i].out();
    //  printf("G %d ", i), bst::g[i].out();
    //  printf("H %d ", i), bst::h[i].out();
    // }
    // for (int i = 1; i <= n; i++) printf("[%d %d]%c", bst::lf[i].v, bst::lf[i].z, " \n"[i == n]);
    // for (int i = 1; i <= n; i++) printf("[%d %d]%c", bst::lg[i].v, bst::lg[i].z, " \n"[i == n]);
    for (int F, G, i = std::max(L - 1, 1); i <= std::min(R, n - 1); i++) {
        for (auto u : todo1[i]) bst::modify(u, FLAG_G);
        for (auto u : todo2[i]) bst::modify(u, FLAG_F);
        bst::get_value(bst::rt, F, G);
        sum[i] = mul(dec(1, mul(dec(1, F), G)), fpow(2, m - 1));
        // printf("%d => %d : %d %d\n", i, sum[i], F, G);
        // for (int i = 1, F, G; i <= n; i++) bst::get_value(i, F, G), print(F, " \n"[i == n]);
        // for (int i = 1, F, G; i <= n; i++) bst::get_value(i, F, G), print(G, " \n"[i == n]);
    }

    for (int i = L; i <= R; i++) ans[i] = dec(sum[i], sum[i - 1]);
    ans[1] = inc(sum[1], fpow(2, m - 1));
    ans[n] = dec(dec(fpow(2, m - 1), 1), sum[n - 1]);
    print(ans, L, R);
    // print(ans, 1, n, "ans"), print(sum, 1, n, "sum");
}

} signed main() {
#ifdef MEMSET0_LOCAL_ENVIRONMENT
    freopen("1.in", "r", stdin);
    // freopen("1.out", "w", stdout);
#endif
    return ringo::main(), 0;
}

CF809E Surprise me!
上一篇 «
CF908H New Year and Boolean Bridges
» 下一篇