山有木兮木有枝
心悦君兮君不知

DP 。用 f_{u, 0} 表示 u 点被染黑的概率,f_{u, 1} 表示 u 点被染白,但是其祖先中有点被染黑的概率,用 f_{u, 2} 表示 u 点被染白,且其祖先没有被染黑的点的概率。所求答案即 \displaystyle{\sum_{i=1}^n f_{i, 0}}

考虑转移,对于一次线段树上的修改操作,可以把线段树上的点分为以上几类。

  • 在路径的终点的点

对应 p[u].l == l && p[u].r == r

这样的点的无论如何都会被染黑,即

\left( \begin{matrix} f_0 \\ f_1 \\ f_2 \end{matrix} \right) \Rightarrow \left( \begin{matrix} f_0 + \frac 12 f_1 + \frac 12 f_2 \\ \frac 12 f_1 \\ \frac 12 f_2 \end{matrix} \right)
  • 在路径上但不在终点的点

这样的点的标记会被 pushdown 且自己没有被打上标记

\left( \begin{matrix} f_0 \\ f_1 \\ f_2 \end{matrix} \right) \Rightarrow \left( \begin{matrix} \frac 12 f_0 \\ \frac 12 f_1 \\ \frac 12 f_0 + \frac 12 f_1 + f_2 \\ \end{matrix} \right)
  • 在路径上的点旁边的点

这样的点会被其祖先的标记 pushdown 到,但由于自己没有被经过所以标记不会被 pushdown

\left( \begin{matrix} f_0 \\ f_1 \\ f_2 \end{matrix} \right) \Rightarrow \left( \begin{matrix} f_0 + \frac 12 f_1 \\ \frac 12 f_1 \\ f_2 \\ \end{matrix} \right)
  • 在终点的子树里的点

这样的点的祖先中有个点被打上了标记,要更新前面的状态

\left( \begin{matrix} f_0 \\ f_1 \\ f_2 \end{matrix} \right) \Rightarrow \left( \begin{matrix} f_0 \\ f_1 + \frac 12 f_2\\ \frac 12 f_2 \\ \end{matrix} \right)

上面的这些状态转移可以用矩阵完成。

最后一种可以对 dfs 序建一棵线段树来进行区间乘矩阵,也可以直接打成懒标记来下放。前者复杂度为 O(n \log^2 n) ,后者复杂度为 O(n \log n)

也来说点无用的废话呢 ...

考试的时候以为这题并不可做,只花了 15min 写了暴力就没有再想它。

一直在调 T3 的 DP ,然而因为刚开始写了个假的 + T1 发现暴力写挂了,并没有调出来。

T2 考完出来才知道并不难,仔细想想很快就会了。然而一切都已经结束了呢。

NOIP 和 ZJOI Round 1 都已经考过了,自己的分数也的确有点过低了,

有些时候尽管还是很难以接受自己的失败,但是总归是要面对的:

加油吧。

代码:

// =================================
//   author: memset0
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define ll long long
#define debug(...) ((void)0)
#ifndef debug
#define debug(...) fprintf(stderr,__VA_ARGS__)
#endif
namespace ringo {
template <class T> inline void read(T &x) {
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}
template <class T> inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}
template <class T> inline void print(T x, char c) { print(x), putchar(c); }

const int N = 1e5 + 10, L = 3, mod = 998244353;
int n, m, ans;

inline int dec(int a, int b) { a -= b; return a < 0 ? a + mod : a; }
inline int sub(int a, int b) { a += b; return a >= mod ? a - mod : a; }
inline int mul(int a, int b) { return (ll)a * b - (ll)a * b / mod * mod; } 

inline int sub(int a, int b, int c) { return sub(a, sub(b, c)); }
inline int half(int x) { return x & 1 ? (x + mod) >> 1 : x >> 1; }

struct matrix {
#define f0 a[0][0]
#define f1 a[1][0]
#define f2 a[2][0]
    int a[L][L];
    inline matrix() {}
    inline matrix(char c) { memset(a, c, sizeof(a)); }
    inline matrix(int x, int y, int z) { a[0][0] = x, a[1][0] = y, a[2][0] = z; }
    inline void out() const {
        for (int i = 0; i < 3; i++)
            printf("{%d %d %d}%c", a[i][0], a[i][1], a[i][2], " \n"[i == 2]);
    }
    friend inline matrix operator * (const matrix &a, const matrix &b) {
        matrix c(0);
        for (register int i = 0; i < L; i++)
            for (register int j = 0; j < L; j++)
                for (register int k = 0; k < L; k++)
                    c.a[i][j] = (c.a[i][j] + (ll)a.a[i][k] * b.a[k][j]) % mod;
        return c;
    }
    inline matrix move0() { return matrix(sub(f0, half(f1), half(f2)), half(f1), half(f2)); }
    inline matrix move1() { return matrix(half(f0), half(f1), sub(half(f0), half(f1), f2)); }
    inline matrix move2() { return matrix(sub(f0, half(f1)), half(f1), f2); }
#undef f0
#undef f1
#undef f2
} I, A, pow[N];

void matrix_init() {
    for (int i = 0; i < L; i++) I.a[i][i] = 1;
    pow[0] = I;
    A.a[0][0] = 1, A.a[0][1] = 0, A.a[0][2] =  0;
    A.a[1][0] = 0, A.a[1][1] = 1, A.a[1][2] = (mod + 1) >> 1;
    A.a[2][0] = 0, A.a[2][1] = 0, A.a[2][2] = (mod + 1) >> 1;
    for (int i = 1; i <= m; i++) pow[i] = pow[i - 1] * A;
}

struct node {
    int l, r, mid, tag;
    matrix x;
} p[N << 2];

void move0(int u) {
//  printf(">> move0 %d <= %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
    ans = dec(ans, p[u].x.a[0][0]);
    p[u].x = p[u].x.move0();
    ans = sub(ans, p[u].x.a[0][0]);
//  printf(">> move0 %d => %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
}

void move1(int u) {
//  printf(">> move1 %d <= %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
    ans = dec(ans, p[u].x.a[0][0]);
    p[u].x = p[u].x.move1();
    ans = sub(ans, p[u].x.a[0][0]);
//  printf(">> move1 %d => %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
}

void move2(int u) {
//  printf(">> move2 %d <= %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
    ans = dec(ans, p[u].x.a[0][0]);
    p[u].x = p[u].x.move2();
    ans = sub(ans, p[u].x.a[0][0]);
//  printf(">> move2 %d => %d %d %d\n", u, p[u].x.a[0][0], p[u].x.a[1][0], p[u].x.a[2][0]);
}

void pushup(int u, int k) {
//  printf("pushup %d %d\n", u, k);
    p[u].tag += k;
    ans = dec(ans, p[u].x.a[0][0]);
    p[u].x = pow[k] * p[u].x;
    ans = sub(ans, p[u].x.a[0][0]);
}

void pushdown(int u) {
    if (p[u].l == p[u].r) return;
    pushup(u << 1, p[u].tag);
    pushup(u << 1 | 1, p[u].tag);
    p[u].tag = 0;
}

void build(int u, int l, int r) {
    p[u].l = l, p[u].r = r, p[u].mid = (l + r) >> 1;
    p[u].x.a[2][0] = 1;
    if (l == r) { return; }
    build(u << 1, l, p[u].mid);
    build(u << 1 | 1, p[u].mid + 1, r);
}

void modify(int u, int l, int r) {
//  printf(">> modify %d %d %d [%d %d]\n", u, l, r, p[u].l, p[u].r);
    pushdown(u);
    if (p[u].l == l && p[u].r == r) {
        move0(u);
        if (p[u].l != p[u].r) {
            pushup(u << 1, 1);
            pushup(u << 1 | 1, 1);
        }
        return;
    }
    move1(u);
    if (r <= p[u].mid) {
        modify(u << 1, l, r);
        move2(u << 1 | 1);
    } else if (l > p[u].mid) {
        modify(u << 1 | 1, l, r);
        move2(u << 1);
    } else {
        modify(u << 1, l, p[u].mid);
        modify(u << 1 | 1, p[u].mid + 1, r);
    }
}

void dfs(int u) {
    pushdown(u);
    if (p[u].l == p[u].r) return;
    dfs(u << 1), dfs(u << 1 | 1);
}

void main() {
    read(n), read(m);
    matrix_init();
    build(1, 1, n);
    for (int i = 1, l, r, opt, times = 1; i <= m; i++)
        if (read(opt), opt == 1) {
            read(l), read(r);
            modify(1, l, r);
            times = sub(times, times);
        } else {
//          printf(">> %d * %d = %d\n", ans, times, mul(ans, times));
            print(mul(ans, times), '\n');
        }
}

} signed main() {
#ifdef MEMSET0_LOCAL_ENVIRONMENT
    freopen("1.in", "r", stdin);
#endif
    return ringo::main(), 0;
}

线段树 矩阵优化 DP

Codeforces 数据结构做题记录
上一篇 «
洛谷5278 算术天才⑨与等差数列
» 下一篇