• 快速傅里叶变换 FFT
  • 数论傅里叶变换 NTT
  • 任意模数 NTT
  • 多项式求逆
  • 未完待续……

快速傅里叶变换 FFT

#include <bits/stdc++.h>
#define il inline
#define rg register
#define ll long long
#define getc getchar
#define putc putchar
#define rep(i, l, r) for (int i = l; i <= r; ++i)
namespace ringo {

// Reads one decimal integer from stdin into x. Characters before the first
// digit are skipped, and every '-' among them flips the sign. The character is
// held in an int so EOF is distinguishable (and isdigit never receives a
// negative char value); reaching EOF ends the scan instead of looping forever.
template < class T > inline void read(T &x) {
    x = 0;
    int c = getchar();
    bool negative = false;
    while (c != EOF && !isdigit(c)) negative ^= (c == '-'), c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (negative) x = -x;
}

// Writes the integer x to stdout in decimal, prefixed with '-' when negative.
template < class T > inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

const int maxn = 4e6 + 10;
const double pi = std::acos(-1.0);
// n: transform length (a power of two), k: log2(n), n1/n2: input degrees,
// rev: bit-reversal permutation of [0, n).
int n, k, n1, n2, rev[maxn];

// Minimal complex number providing only the operations the FFT needs.
struct complex {
    double x, y;  // real and imaginary parts
    complex(const double &a = 0, const double &b = 0) { x = a, y = b; }
    complex operator + (const complex &b) const { return complex(x + b.x, y + b.y); }
    complex operator - (const complex &b) const { return complex(x - b.x, y - b.y); }
    complex operator * (const complex &b) const { return complex(x * b.x - y * b.y, x * b.y + y * b.x); }
} a[maxn], b[maxn];

// In-place iterative radix-2 FFT of a[0, n) using the global rev table.
// flag = 1 performs the forward transform, flag = -1 the inverse transform
// WITHOUT the 1/n normalisation (the caller divides by n).
void fft(complex *a, int flag) {
    for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int len = 1; len < n; len <<= 1) {
        // Root of unity of order 2*len; the sign of its imaginary part selects the direction.
        complex wn(std::cos(pi / len), flag * std::sin(pi / len));
        for (int i = 0; i < n; i += (len << 1)) {
            complex w(1, 0);
            for (int j = 0; j < len; j++, w = w * wn) {
                complex x = a[i + j], y = w * a[i + j + len];
                a[i + j] = x + y, a[i + j + len] = x - y;
            }
        }
    }
}

// Reads degrees n1, n2 followed by n1 + 1 and n2 + 1 coefficients, and prints
// the n1 + n2 + 1 coefficients of the product polynomial separated by spaces.
void main() {
    read(n1), read(n2);
    for (int i = 0; i <= n1; i++) read(a[i].x);
    for (int i = 0; i <= n2; i++) read(b[i].x);
    n = 1; while (n <= n1 + n2) n <<= 1, ++k;
    // (n >> 1) equals 1 << (k - 1) whenever k >= 1, and unlike the shift it is
    // well defined when n1 + n2 == 0 makes k == 0 (a shift by -1 is UB).
    for (int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i];
    fft(a, -1);
    for (int i = 0; i <= n1 + n2; i++) {
        // llround rounds to nearest for negative values too (the previous
        // (int)(v + 0.5) truncated toward zero and was off by one for v < 0),
        // and returns long long so large coefficients do not overflow int.
        print(std::llround(a[i].x / n));
        putchar(i == n1 + n2 ? '\n' : ' ');
    }
}

} int main() { return ringo::main(), 0; }

数论傅里叶变换 NTT

#include <bits/stdc++.h>
#define il inline
#define rg register
#define ll long long
#define getc getchar
#define putc putchar
#define rep(i, l, r) for (int i = l; i <= r; ++i)
namespace ringo {

// Reads one decimal integer from stdin into x. Characters before the first
// digit are skipped, and every '-' among them flips the sign. The character is
// held in an int so EOF is distinguishable; reaching EOF ends the scan instead
// of looping forever.
template < class T > inline void read(T &x) {
    x = 0;
    int c = getchar();
    bool negative = false;
    while (c != EOF && !isdigit(c)) negative ^= (c == '-'), c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (negative) x = -x;
}

// Writes the integer x to stdout in decimal, prefixed with '-' when negative.
template < class T > inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

// P is the NTT prime 998244353 and G = 3 is the root used for the transforms.
const int maxn = 4e6 + 10, P = 998244353, G = 3;
// n: transform length (a power of two), k: log2(n), n1/n2: input degrees,
// tmp: inverse of n mod P, rev: bit-reversal permutation of [0, n).
int n, k, n1, n2, tmp, rev[maxn];
long long a[maxn], b[maxn];

// Modular inverse of x modulo P via inv(x) = (P - P / x) * inv(P % x) mod P.
// Returns 1 for x == 0 (which has no inverse).
int inv(int x) {
    if (x == 0 || x == 1) return 1;
    return 1ll * (P - P / x) * inv(P % x) % P;
}

// x^b mod P by binary exponentiation; expects 0 <= x < P.
int pow(long long x, int b) {
    long long s = 1;
    while (b) {
        if (b & 1) s = s * x % P;
        x = x * x % P, b >>= 1;
    }
    return s;
}

// In-place iterative NTT of a[0, n) modulo P using the global rev table.
// root = G gives the forward transform; root = inv(G) gives the inverse
// transform WITHOUT the 1/n normalisation (the caller multiplies by inv(n)).
void ntt(long long *a, int root) {
    for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int len = 1; len < n; len <<= 1) {
        long long wn = pow(root, (P - 1) / (len << 1));
        for (int i = 0; i < n; i += (len << 1)) {
            long long w = 1;
            for (int j = 0; j < len; j++, w = w * wn % P) {
                long long x = a[i + j], y = w * a[i + j + len] % P;
                a[i + j] = (x + y) % P, a[i + j + len] = (x - y + P) % P;
            }
        }
    }
}

// Reads degrees n1, n2 and both coefficient lists, and prints the n1 + n2 + 1
// coefficients of the product polynomial modulo P, separated by spaces.
void main() {
    read(n1), read(n2);
    // Normalise into [0, P): a plain % leaves negative input negative, which
    // the butterflies below do not handle.
    for (int i = 0; i <= n1; i++) read(a[i]), a[i] = (a[i] % P + P) % P;
    for (int i = 0; i <= n2; i++) read(b[i]), b[i] = (b[i] % P + P) % P;
    n = 1; while (n <= n1 + n2) n <<= 1, ++k;
    // (n >> 1) equals 1 << (k - 1) for k >= 1 and avoids a shift by -1 (UB)
    // when n1 + n2 == 0 makes k == 0.
    for (int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    ntt(a, G), ntt(b, G);
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % P;
    ntt(a, inv(G)), tmp = inv(n);
    for (int i = 0; i <= n1 + n2; i++) print(a[i] * tmp % P), putchar(i == n1 + n2 ? '\n' : ' ');
}

} int main() { return ringo::main(), 0; }

任意模数 NTT

#include <bits/stdc++.h>
namespace ringo {
typedef long long ll;

// Reads one decimal integer from stdin into x. Characters before the first
// digit are skipped, and every '-' among them flips the sign. The character is
// held in an int so EOF is distinguishable; reaching EOF ends the scan instead
// of looping forever.
template < class T >
inline void read(T &x) {
    x = 0;
    int c = getchar();
    bool negative = false;
    while (c != EOF && !isdigit(c)) negative ^= (c == '-'), c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (negative) x = -x;
}

// Writes the integer x to stdout in decimal, prefixed with '-' when negative.
template < class T >
inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

// Writes x followed by the separator character c.
template < class T >
inline void print(T x, char c) {
    print(x), putchar(c);
}

// Three NTT primes, each used with root G = 3. The exact convolution values
// are reconstructed by CRT, so they must stay below P[0] * P[1] * P[2]
// (about 4.7e26).
const int N = 4e5 + 10, G = 3, P[3] = {469762049, 998244353, 1004535809};
// n1/n2: input degrees, k: log2(n), n: transform length, p: output modulus,
// p1/p2/M2: the three CRT moduli (copied from P in main).
int n1, n2, k, n, p, p1, p2, M2;
int a[N], b[N], f[3][N], g[N], rev[N], ans[N];
// CRT coefficients, computed once in main() rather than per coefficient:
// inv(p2) mod p1, inv(p1) mod p2, inv(p1 * p2) mod M2.
ll inv21, inv12, invM1;

// Modular inverse of x modulo prime p (x is reduced mod p first).
// Returns 1 when x is 0 mod p (which has no inverse).
int inv(int x, int p) {
    if (x >= p) return inv(x % p, p);
    return !x || x == 1 ? 1 : (ll)(p - p / x) * inv(p % x, p) % p;
}

// a^b mod p by binary exponentiation.
int pow(int a, int b, int p) {
    int s = 1;
    while (b) {
        if (b & 1) s = (ll)s * a % p;
        b >>= 1, a = (ll)a * a % p;
    }
    return s;
}

// In-place iterative NTT of a[0, n) modulo prime p using the global rev table.
// g = G gives the forward transform; g = inv(G, p) gives the inverse
// transform WITHOUT the 1/n normalisation.
void ntt(int *a, int g, int p) {
    for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int len = 1; len < n; len <<= 1) {
        int wn = pow(g, (p - 1) / (len << 1), p);
        for (int i = 0; i < n; i += (len << 1)) {
            int w = 1;
            for (int j = 0; j < len; j++, w = (ll)w * wn % p) {
                int x = a[i + j], y = (ll)w * a[i + j + len] % p;
                a[i + j] = (x + y) % p, a[i + j + len] = (x - y + p) % p;
            }
        }
    }
}

// Given X mod p1 (a1), X mod p2 (a2) and X mod M2 (A2) for some
// 0 <= X < p1 * p2 * M2, returns X mod p. Relies on inv21/inv12/invM1 having
// been set by main().
int merge(int a1, int a2, int A2) {
    const ll M1 = (ll)p1 * p2;
    // First combine the two smaller moduli: A1 == X (mod M1), 0 <= A1 < M1.
    const ll A1 = (inv21 * a1 % p1 * p2 + inv12 * a2 % p2 * p1) % M1;
    // X = A1 + M1 * K with 0 <= K < M2; solve for K modulo M2.
    const ll K = ((A2 - A1) % M2 + M2) % M2 * invM1 % M2;
    // Every product below stays under ~1.5e18, within long long range.
    return (A1 + M1 % p * K) % p;
}

// Reads degrees n1, n2, the modulus p and both coefficient lists, and prints
// the n1 + n2 + 1 coefficients of the product polynomial modulo p.
void main() {
    read(n1), read(n2), read(p);
    p1 = P[0], p2 = P[1], M2 = P[2];
    inv21 = inv(p2, p1), inv12 = inv(p1, p2);
    invM1 = inv((int)((ll)p1 * p2 % M2), M2);
    // Reduce inputs into [0, p): only the result mod p matters. This keeps
    // every exact product coefficient non-negative, as the CRT step requires
    // (negative input used to be reconstructed incorrectly). It also keeps
    // each coefficient at most (min(n1, n2) + 1) * (p - 1)^2.
    for (int i = 0; i <= n1; i++) read(a[i]), a[i] = (int)(((ll)a[i] % p + p) % p);
    for (int i = 0; i <= n2; i++) read(b[i]), b[i] = (int)(((ll)b[i] % p + p) % p);
    n = 1; while (n <= n1 + n2) n <<= 1, ++k;
    // (n >> 1) equals 1 << (k - 1) for k >= 1 and avoids a shift by -1 (UB)
    // when n1 + n2 == 0 makes k == 0.
    for (int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    for (int t = 0; t < 3; t++) {
        const int mod = P[t];
        for (int i = 0; i < n; i++) f[t][i] = a[i] % mod, g[i] = b[i] % mod;
        ntt(f[t], G, mod), ntt(g, G, mod);
        for (int i = 0; i < n; i++) f[t][i] = (ll)f[t][i] * g[i] % mod;
        ntt(f[t], inv(G, mod), mod);
        const ll inv_n = inv(n, mod);  // loop-invariant: computed once per prime
        for (int i = 0; i < n; i++) f[t][i] = f[t][i] * inv_n % mod;
    }
    for (int i = 0; i <= n1 + n2; i++) ans[i] = merge(f[0][i], f[1][i], f[2][i]);
    for (int i = 0; i <= n1 + n2; i++) print(ans[i], " \n"[i == n1 + n2]);
}

} signed main() { return ringo::main(), 0; }

多项式求逆

// =================================
// author: memset0
// date: 2018.12.03 12:19:03
// website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
namespace ringo {
typedef long long ll;

// Reads one decimal integer from stdin into x. Characters before the first
// digit are skipped, and every '-' among them flips the sign. The character is
// held in an int so EOF is distinguishable; reaching EOF ends the scan instead
// of looping forever.
template < class T >
inline void read(T &x) {
    x = 0;
    int c = getchar();
    bool negative = false;
    while (c != EOF && !isdigit(c)) negative ^= (c == '-'), c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (negative) x = -x;
}

// Writes the integer x to stdout in decimal, prefixed with '-' when negative.
template < class T >
inline void print(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}

// Writes x followed by the separator character c.
template < class T >
inline void print(T x, char c) {
    print(x), putchar(c);
}

const int N = 4e5 + 10, p = 998244353;
// n: number of input coefficients; lim/k: length and log2 of the current NTT;
// lim_inv: inverse of lim mod p; f/g: NTT scratch buffers; rev: bit reversal.
int n, k, lim, lim_inv;
int a[N], b[N], f[N], g[N], rev[N];

// a^b mod p by binary exponentiation.
int fpow(int a, int b) {
    int s = 1;
    while (b) {
        if (b & 1) s = (ll)s * a % p;
        b >>= 1, a = (ll)a * a % p;
    }
    return s;
}

// In-place iterative NTT of a[0, lim) modulo p using the global rev table.
// root = 3 gives the forward transform; root = 3^{-1} gives the inverse
// transform WITHOUT the 1/lim normalisation.
void ntt(int *a, int root) {
    for (int i = 0; i < lim; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int len = 1; len < lim; len <<= 1) {
        int wn = fpow(root, (p - 1) / (len << 1));
        for (int i = 0; i < lim; i += (len << 1)) {
            int w = 1;
            for (int j = 0; j < len; j++, w = (ll)w * wn % p) {
                int x = a[i + j], y = (ll)w * a[i + j + len] % p;
                a[i + j] = (x + y) % p, a[i + j + len] = (x - y + p) % p;
            }
        }
    }
}

// Sets b[0, n) to the inverse of a modulo x^n (n must be a power of two) by
// Newton iteration. From B = a^{-1} mod x^{n/2} it takes
// b = 2B - a * B^2 mod x^n. Expects b[n/2, n) to be zero on entry, which holds
// for the zero-initialised global b used by main. Assumes a[0] != 0 (mod p);
// otherwise no inverse exists.
void solve(int *a, int *b, int n) {
    if (n == 1) { b[0] = fpow(a[0], p - 2); return; }
    solve(a, b, n >> 1);
    // a * B^2 has degree < 2n - 1, so a length-2n cyclic convolution does not wrap.
    lim = n << 1, k = 0;
    while ((1 << k) < lim) ++k;
    // This level's NTTs only touch [0, lim), so clearing that range is enough
    // (previously the whole of f and g was memset at every recursion level).
    std::fill(f, f + lim, 0), std::fill(g, g + lim, 0);
    for (int i = 0; i < n; i++) f[i] = a[i], g[i] = b[i];
    for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));  // k >= 1 here
    ntt(f, 3), ntt(g, 3);
    for (int i = 0; i < lim; i++) f[i] = (ll)f[i] * g[i] % p * g[i] % p;
    ntt(f, fpow(3, p - 2)), lim_inv = fpow(lim, p - 2);
    for (int i = 0; i < n; i++) b[i] = (2ll * b[i] - (ll)f[i] * lim_inv % p + p) % p;
}

// Reads n and the n coefficients of a, and prints the first n coefficients
// of a^{-1} modulo p, separated by spaces.
void main() {
    read(n);
    for (int i = 0; i < n; i++) read(a[i]);
    int len = 1;
    while (len < n) len <<= 1;  // Newton iteration needs a power-of-two length
    solve(a, b, len);
    for (int i = 0; i < n; i++) print(b[i], " \n"[i == n - 1]);
}

} signed main() { return ringo::main(), 0; }