考虑最暴力的做法,列出 n 个生成函数,FWT 到一起,复杂度 \mathcal O(n \times 2^k \times k) ,显然不能通过。

考虑三元组 (a_i, b_i, c_i)ans_k 的贡献,等价于三元组 (0, a_i \operatorname{xor} b_i, a_i \operatorname{xor} c_i)ans_{k \operatorname{xor} a_i} 的贡献。考虑这样将题意转化后,生成函数只有 0a_i \operatorname{xor} b_ia_i \operatorname{xor} c_i 下标上有值(分别为 a, b, c)。那么经过一次 FWT 后,只有四种值,分别为 a+b+c, a+b-c, a-b+c, a-b-c

考虑将 n 个生成函数加起来做 FWT ,假设 a+b+c, a+b-c, a-b+c, a-b-c 分别在其中出现了 x, y, z, w 次,那么这个位置在正常做法情况下的值应该为 (a+b+c)^x \times (a+b-c)^y \times (a-b+c)^z \times (a-b-c)^w ,也就是说我们需要解出这个 x, y, z, w 的值。

首先显然的 x + y + z + w = pp 是这位上当前的值。需要注意的是,这个 a, b, c 并不一定需要是题目中给出的,所以我们可以分别取 (0, 1, 0)(0, 0, 1) ,列出两个方程,此时我们若再通过改变 (a, b, c) 来列方程,则可以通过两个已有的来表示,本质相同的,需要考虑别的方式得到新的方程。

考虑构造一个新的多项式,对于每个 k = b_i \operatorname{xor} c_ik \text{+=} 1 ,这样的话可以得到一个新的方程 x + y + z + w = qq 是这位上当前的值。

这样的话我们就得到了四个本质不同方程,可以解除上面的 x, y, z, w ,得出正常 FWT 后应该得到的点值,再将其 IFWT 一次即可得到答案。

出于一些原因这篇题解的变量名可能有点重复,请自行感谢理解 :)

代码(变量名可能和上面题解不大一样):

// =================================
//   author: memset0
//   date: 2019.04.07 18:24:19
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define ll long long
#define debug(...) ((void)0)
#ifndef debug
#define debug(...) fprintf(stderr,__VA_ARGS__)
#endif
namespace ringo {
template <class T> inline void read(T &x) {
    x = 0; register char c = getchar(); register bool f = 0;
    while (!isdigit(c)) f ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    if (f) x = -x;
}
template <class T> inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}
template <class T> inline void print(T x, char c) { print(x), putchar(c); }

const int N = 2e5 + 10, mod = 998244353;
int n, m, p, q, r, lim, tot;
int a[N], b[N], c[N], f[N], g[N], h[N];

inline int dec(int a, int b) { a -= b; return a < 0 ? a + mod : a; }
inline int sub(int a, int b) { a += b; return a >= mod ? a - mod : a; }
inline int mul(int a, int b) { return (ll)a * b - (ll)a * b / mod * mod; }
inline int inv(int x) { return x < 2 ? 1 : mul(mod - mod / x, inv(mod % x)); }
inline int fpow(int a, int b) { int s = 1; for (; b; b >>= 1, a = mul(a, a)) if (b & 1) s = mul(s, a); return s; }

void fwt(int *a) {
    for (int len = 1; len < lim; len <<= 1) 
        for (int i = 0; i < lim; i += (len << 1))
            for (int j = 0; j < len; j++) {
                int x = a[i + j], y = a[i + j + len];
                a[i + j] = sub(x, y), a[i + j + len] = dec(x, y);
            }
}

void ifwt(int *a) {
    for (int len = 1; len < lim; len <<= 1)
        for (int i = 0; i < lim; i += (len << 1))
            for (int j = 0; j < len; j++) {
                int x = a[i + j], y = a[i + j + len];
                a[i + j] = mul(sub(x, y), inv(2)), a[i + j + len] = mul(dec(x, y), inv(2));
            }
}

void main() {
    read(n), read(m), read(p), read(q), read(r);
    for (int i = 1; i <= n; i++) read(a[i]), read(b[i]), read(c[i]);
    for (int i = 1; i <= n; i++) tot ^= a[i], ++f[a[i] ^ b[i]], ++g[a[i] ^ c[i]], ++h[b[i] ^ c[i]];
    lim = 1 << m;
    fwt(f), fwt(g), fwt(h);
    int t0 = sub(sub(p, q), r), t1 = dec(sub(p, q), r), t2 = sub(dec(p, q), r), t3 = dec(dec(p, q), r);
    for (int i = 0; i < lim; i++) {
        int x = mul(sub(n, sub(f[i], sub(g[i], h[i]))), inv(4)),
            y = dec(mul(sub(n, f[i]), inv(2)), x),
            z = dec(mul(sub(n, g[i]), inv(2)), x),
            w = dec(mul(sub(n, h[i]), inv(2)), x);
        f[i] = mul(fpow(t0, x), mul(fpow(t1, y), mul(fpow(t2, z), fpow(t3, w))));
    }
    ifwt(f);
    for (int i = 0; i < lim; i++) print(f[i ^ tot], ' ');   
}

} signed main() {
#ifdef MEMSET0_LOCAL_ENVIRONMENT
    freopen("1.in", "r", stdin);
#endif
    return ringo::main(), 0;
}

巧妙的思路 FWT

洛谷5293 [HNOI2019]白兔之舞
上一篇 «
Codeforces 数据结构做题记录
» 下一篇