动态 DP 学习笔记

(好懒不想写博客)

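把每个点的 dp 转移拆成两部分:来自重儿子的转移,以及来自自己和轻儿子的转移。重儿子部分写成矩阵,放到重链上用数据结构(线段树或全局平衡二叉树)维护,利用矩阵乘法的结合律;轻儿子部分在每次修改时暴力维护,利用轻儿子的贡献可以直接减去旧值、再加上新值的性质。以最大权独立集为例,设 $g_{u,0}$、$g_{u,1}$ 分别为只考虑 $u$ 自己和它的轻儿子时不选/选 $u$ 的答案,$v$ 为 $u$ 的重儿子,则在 $(\max,+)$ 矩阵乘法意义下有 $$\begin{bmatrix} f_{u,0} \\ f_{u,1} \end{bmatrix} = \begin{bmatrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty \end{bmatrix} \begin{bmatrix} f_{v,0} \\ f_{v,1} \end{bmatrix}$$ 详情参考 txc 哥哥的博客。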

树剖代码

// =================================
// author: memset0
// date: 2018.12.18 09:15:21
// website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for (register int i = l; i <= r; i++)
namespace ringo {
// Short aliases for the 64-bit integer types.
using ll = long long;
using ull = unsigned long long;
// Reads the next integer from stdin into x.
// Every non-digit character before the number is skipped; each '-' seen in
// that skipped prefix toggles the sign. If input ends before any digit is
// found, x is left as 0 instead of spinning forever on EOF.
template <class T> inline void read(T &x) {
    x = 0;
    // Keep getchar()'s int result: it keeps EOF distinct from byte 0xFF and
    // is always a valid argument for isdigit() (a plain char may be negative).
    // `register` is dropped because C++17 removed that storage class.
    int c = getchar();
    bool neg = false;
    while (c != EOF && !isdigit(c)) neg ^= c == '-', c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (neg) x = -x;
}
// Writes the decimal representation of x to stdout (no separator).
template <class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        // Peel off the last digit before negating: -x overflows for the most
        // negative value of a signed T, but -(x / 10) never does.
        const T head = x / 10;
        const int last = -static_cast<int>(x % 10);  // x % 10 lies in [-9, 0] here
        if (head != 0) print(static_cast<T>(-head));
        putchar('0' + last);
        return;
    }
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}
// In-place maximum: a takes the value of b when b is larger.
template <class T> inline void maxd(T &a, T b) {
    if (a < b) a = b;
}
// In-place minimum: a takes the value of b when b is smaller.
template <class T> inline void mind(T &a, T b) {
    if (a > b) a = b;
}
// Prints x, then the single separator character c.
template <class T> inline void print(T x, char c) {
    print(x);
    putchar(c);
}
// Returns -a when a is negative, otherwise a.
template <class T> inline T abs(const T &a) {
    return a < 0 ? -a : a;
}

// Array capacity (vertex limit plus slack) and the "minus infinity" magnitude
// used in the transition matrices. In this program `int` is macro-expanded to
// `long long`, which is what lets inf hold 1e18.
const int N = 1e5 + 10;
const int inf = 1e18;

// Input scratch: vertex count, query count, edge/query operands, weight
// delta, and the running DFS position counter used by dfs2.
int n, m, u, v, w, dta, pos;

// Per-vertex arrays: weight, heavy child, parent, chain top, DFS position,
// vertex at a DFS position, chain bottom (leaf), subtree size, depth.
using R = int[N];
R a, son, fa, top, id, wid, bot, siz, dep;

// Adjacency lists as linked edge lists; edge ids start at 2 (0 ends a list).
int tot = 2, hed[N], nxt[N << 1], to[N << 1];

struct matrix {
int a[2][2];
inline int* operator[] (const size_t &i) { return a[i]; }
void operator ~ () { rep(i, 0, 1) rep(j, 0, 1) print(a[i][j], " \n"[j == 1]); }
} f[N], g[N];

// Segment-tree node over DFS positions. It stores the ordered product of
// g[wid[l]] ... g[wid[r]] for the positions it covers.
struct node {
    int l, r;   // covered DFS positions, inclusive
    int mid;    // split point: the left child covers [l, mid]
    matrix x;   // product of the leaf matrices in [l, r], left to right
} p[N << 2];

// (max, +) matrix product: c[i][j] = max over k of (a[i][k] + b[k][j]).
// Every entry starts from a very negative byte-filled sentinel (-0x3f in each
// byte) that any real candidate sum replaces.
inline matrix operator * (const matrix &a, const matrix &b) {
    matrix c;
    memset(c.a, -0x3f, sizeof(c.a));
    for (int row = 0; row < 2; ++row) {
        for (int col = 0; col < 2; ++col) {
            for (int mid = 0; mid < 2; ++mid) {
                const int cand = a.a[row][mid] + b.a[mid][col];
                if (cand > c.a[row][col]) c.a[row][col] = cand;
            }
        }
    }
    return c;
}

// First DFS from u: fills parent, depth and subtree size, and picks each
// vertex's heavy child (the child with the largest subtree; first one wins ties).
void dfs1(int u) {
    siz[u] = 1;
    for (int e = hed[u]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v]) son[u] = v;
    }
}

// Second DFS: numbers vertices with the heavy child visited first, so every
// heavy chain occupies a contiguous block of DFS positions. Records each
// vertex's chain top and chain bottom (the leaf that ends the chain).
void dfs2(int u, int toppoint) {
    top[u] = toppoint;
    id[u] = ++pos;
    wid[pos] = u;
    if (siz[u] == 1) {
        bot[u] = u;  // a leaf terminates its chain
        return;
    }
    dfs2(son[u], toppoint);
    bot[u] = bot[son[u]];
    for (int e = hed[u]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);  // every light child starts a new chain
    }
}

// Bottom-up maximum-weight independent set DP.
//   f[u][0][0]: best total in u's subtree with u not chosen.
//   f[u][1][0]: best total in u's subtree with u chosen.
// g[u] gets the same two sums restricted to u itself plus its light children.
void dfs3(int u) {
    f[u][0][0] = 0;
    g[u][0][0] = 0;
    f[u][1][0] = a[u];
    g[u][1][0] = a[u];
    for (int e = hed[u]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fa[u]) continue;
        dfs3(v);
        const int whenOut = std::max(f[v][0][0], f[v][1][0]);  // u not chosen: v is free
        const int whenIn = f[v][0][0];                          // u chosen: v must be skipped
        f[u][0][0] += whenOut;
        f[u][1][0] += whenIn;
        if (v == son[u]) continue;
        g[u][0][0] += whenOut;
        g[u][1][0] += whenIn;
    }
}

void build(int l = 1, int r = n, int u = 1) {
p[u].l = l, p[u].r = r, p[u].mid = (l + r) >> 1;
if (l == r) { p[u].x = g[wid[l]]; return; }
build(l, p[u].mid, u << 1);
build(p[u].mid + 1, r, u << 1 | 1);
p[u].x = p[u << 1].x * p[u << 1 | 1].x;
}

void modify(int k, const matrix &x, int u = 1) {
if (p[u].l == p[u].r) { p[u].x = x; return; }
if (k <= p[u].mid) modify(k, x, u << 1);
else modify(k, x, u << 1 | 1);
p[u].x = p[u << 1].x * p[u << 1 | 1].x;
}

// Returns the ordered product of the leaf matrices at DFS positions [l, r].
// Assumes [l, r] lies inside node u's range.
matrix query(int l, int r, int u = 1) {
    const node &cur = p[u];
    if (l == cur.l && r == cur.r) return cur.x;
    const int lc = u << 1, rc = u << 1 | 1;
    if (r <= cur.mid) return query(l, r, lc);
    if (cur.mid < l) return query(l, r, rc);
    return query(l, cur.mid, lc) * query(cur.mid + 1, r, rc);
}

// Recomputes f[u] for a non-leaf u: the product of g along u's chain from u
// down to the vertex just above the chain bottom, times f[bottom].
// Leaves (chain bottoms) are left untouched.
void maintain(int u) {
    const int b = bot[u];
    if (b != u) f[u] = query(id[u], id[fa[b]]) * f[b];
}

// Point update: adds dta to a[u] and repairs the DP up to the root.
// Invariant: for every chain top t with a parent, g[fa[t]] contains t's
// contribution as a light child: max(f[t][0][0], f[t][1][0]) in entries
// [0][0] and [0][1], and f[t][0][0] in entry [1][0]. Each loop iteration
// removes that contribution, recomputes f[t] with maintain(), adds the new
// contribution back, and pushes the changed g[fa[t]] into the segment tree.
void update(int u, int dta) {
// If u is a chain top, f[u][1][0] is bumped directly just below. The loop's
// first "subtract" would then remove the bumped value instead of the old one.
// These two blocks shift g[fa[u]]'s row 0 from the old max to the bumped
// max, so the net change after the loop is (new contribution - old one).
// Row 1 needs no fix-up because the bump does not touch f[u][0][0].
if (u == top[u]) {
g[fa[top[u]]][0][0] -= std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][0][1] -= std::max(f[top[u]][1][0], f[top[u]][0][0]);
}
// Apply the weight change to u's own transition matrix and to f[u].
// The f[u] bump is exact when u is a leaf (maintain() multiplies by
// f[bot]). For a non-leaf chain top, maintain() below recomputes f[u].
// For any other vertex, f[u] is not read again.
a[u] += dta, f[u][1][0] += dta, g[u][1][0] += dta, modify(id[u], g[u]);
if (u == top[u]) {
g[fa[top[u]]][0][0] += std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][0][1] += std::max(f[top[u]][1][0], f[top[u]][0][0]);
}
// Climb chain by chain. On the root chain fa[top[u]] is 0: g[0] then only
// serves as a scratch slot and the segment-tree write is skipped.
while (u) {
// Remove the chain top's old contribution from its parent's matrix.
g[fa[top[u]]][0][0] -= std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][0][1] -= std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][1][0] -= f[top[u]][0][0];
maintain(top[u]);
// Add the recomputed contribution back.
g[fa[top[u]]][0][0] += std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][0][1] += std::max(f[top[u]][1][0], f[top[u]][0][0]);
g[fa[top[u]]][1][0] += f[top[u]][0][0];
if (fa[top[u]]) modify(id[fa[top[u]]], g[fa[top[u]]]);
u = fa[top[u]];
}
}

// Program driver, called by the global main after the namespace.
// Input: n, m, the n vertex weights, n-1 undirected edges, then m queries
// "u w" meaning a[u] = w. After each query it prints the maximum-weight
// independent set of the tree rooted at vertex 1.
void main() {
    read(n), read(m);
    for (int i = 1; i <= n; i++) read(a[i]);
    for (int i = 1; i < n; i++) {
        read(u), read(v);
        nxt[tot] = hed[u], to[tot] = v, hed[u] = tot++;
        nxt[tot] = hed[v], to[tot] = u, hed[v] = tot++;
    }
    dfs1(1), dfs2(1, 1), dfs3(1);
    // Complete each transition matrix to [[g0, g0], [g1, -inf]], so that
    // f[u] = g[u] * f[heavy child] under the (max, +) product.
    for (int i = 1; i <= n; i++) g[i][0][1] = g[i][0][0], g[i][1][1] = -inf;
    build();
    // Recompute f for every non-leaf through the segment tree.
    for (int i = 1; i <= n; i++) if (i != bot[i]) maintain(i);
    for (int i = 1; i <= m; i++) {
        read(u), read(w), dta = w - a[u], update(u, dta);
        print(std::max(f[1][0][0], f[1][1][0]), '\n');
    }
    // Nothing reads f once the last answer is printed, so no final
    // recomputation pass is needed here.
}

} signed main() { return ringo::main(), 0; }

全局平衡二叉树代码

// =================================
// author: memset0
// date: 2018.12.19 09:56:29
// website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define rep(i, l, r) for (register int i = l; i <= r; i++)
namespace ringo {
// Short aliases for the 64-bit integer types.
using ll = long long;
using ull = unsigned long long;
// Reads the next integer from stdin into x.
// Every non-digit character before the number is skipped; each '-' seen in
// that skipped prefix toggles the sign. If input ends before any digit is
// found, x is left as 0 instead of spinning forever on EOF.
template <class T> inline void read(T &x) {
    x = 0;
    // Keep getchar()'s int result: it keeps EOF distinct from byte 0xFF and
    // is always a valid argument for isdigit() (a plain char may be negative).
    // `register` is dropped because C++17 removed that storage class.
    int c = getchar();
    bool neg = false;
    while (c != EOF && !isdigit(c)) neg ^= c == '-', c = getchar();
    while (c != EOF && isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    if (neg) x = -x;
}
// Writes the decimal representation of x to stdout (no separator).
template <class T> inline void print(T x) {
    if (x < 0) {
        putchar('-');
        // Peel off the last digit before negating: -x overflows for the most
        // negative value of a signed T, but -(x / 10) never does.
        const T head = x / 10;
        const int last = -static_cast<int>(x % 10);  // x % 10 lies in [-9, 0] here
        if (head != 0) print(static_cast<T>(-head));
        putchar('0' + last);
        return;
    }
    if (x > 9) print(x / 10);
    putchar('0' + x % 10);
}
// In-place maximum: a takes the value of b when b is larger.
template <class T> inline void maxd(T &a, T b) {
    if (a < b) a = b;
}
// In-place minimum: a takes the value of b when b is smaller.
template <class T> inline void mind(T &a, T b) {
    if (a > b) a = b;
}
// Prints x, then the single separator character c.
template <class T> inline void print(T x, char c) {
    print(x);
    putchar(c);
}
// Returns -a when a is negative, otherwise a.
template <class T> inline T abs(const T &a) {
    return a < 0 ? -a : a;
}

// Array capacity (vertex limit plus slack) and the "minus infinity"
// magnitude used in the transition matrices.
const int N = 1e5 + 10;
const int inf = 1e9;

// Input scratch and global state: vertex/query counts, operands, root of the
// top-level BST, stack height used by build(), weight delta, and the BST
// children (ch[x][0] = left, ch[x][1] = right).
int n, m, u, v, w, rt, top, dta, ch[N][2];

// Per-vertex arrays: weight, chain stack used by build(), heavy child,
// BST parent (or light-edge parent for a BST root), subtree size,
// tree parent, depth.
using R = int[N];
R a, stk, son, fa, siz, fat, dep;

// Adjacency lists as linked edge lists; edge ids start at 2 (0 ends a list).
int tot = 2, hed[N], to[N << 1], nxt[N << 1];

// 2x2 matrix combined with the (max, +) product defined below.
// f[u]: product f[left child] * g[u] * f[right child] inside u's chain BST.
// g[u]: u's own transition built from u itself plus its light children.
// sum[]: declared here but not referenced anywhere in this program.
struct matrix {
    int a[2][2];

    // Row access, so entries can be written as m[i][j].
    inline int* operator [] (const size_t i) {
        return a[i];
    }

    // Debug helper: prints the four entries in row-major order as "[a b c d]".
    void operator ~ () const {
        printf("[%d %d %d %d]",
               a[0][0], a[0][1],
               a[1][0], a[1][1]);
    }
} f[N], g[N], sum[N];

// (max, +) matrix product:
//   c[i][j] = max(a[i][0] + b[0][j], a[i][1] + b[1][j]).
// Plain assignments avoid the C99 compound-literal form `(matrix){...}`.
// ISO C++ has no compound literals; compilers accept that form only as an
// extension and warn about it under -pedantic.
inline matrix operator * (const matrix &a, const matrix &b) {
    matrix c;
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            c.a[i][j] = std::max(a.a[i][0] + b.a[0][j], a.a[i][1] + b.a[1][j]);
    return c;
}

// First DFS from u: fills tree parent (fat), depth and subtree size, and picks
// each vertex's heavy child (largest subtree; the first one wins ties).
void dfs(int u) {
    siz[u] = 1;
    for (int e = hed[u]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fat[u]) continue;
        fat[v] = u;
        dep[v] = dep[u] + 1;
        dfs(v);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v]) son[u] = v;
    }
}

// Bottom-up maximum-weight independent set DP.
//   f[u][0][0] / f[u][1][0]: best total in u's subtree with u excluded / included.
// g[u] collects the same sums over u itself plus its light children. Its
// [0][1] entry mirrors [0][0], giving the transition [[g0, g0], [g1, .]].
void dfs2(int u) {
    f[u][0][0] = 0;
    g[u][0][0] = 0;
    f[u][1][0] = a[u];
    g[u][1][0] = a[u];
    for (int e = hed[u]; e; e = nxt[e]) {
        const int v = to[e];
        if (v == fat[u]) continue;
        dfs2(v);
        const int whenOut = std::max(f[v][0][0], f[v][1][0]);  // u excluded: v is free
        const int whenIn = f[v][0][0];                          // u included: v skipped
        f[u][0][0] += whenOut;
        f[u][1][0] += whenIn;
        if (v == son[u]) continue;
        g[u][0][0] += whenOut;
        g[u][0][1] += whenOut;
        g[u][1][0] += whenIn;
    }
}

// True when u is not a BST child of fa[u]. In that case u roots its chain's
// BST, and fa[u] is either its light-edge parent or 0.
bool is_root(int u) {
    const int p = fa[u];
    return !(ch[p][0] == u || ch[p][1] == u);
}

void update(int u) {
f[u] = g[u];
if (ch[u][0]) f[u] = f[ch[u][0]] * f[u];
if (ch[u][1]) f[u] = f[u] * f[ch[u][1]];
}

// Builds a BST over the chain vertices stk[l..r], kept in chain order, and
// returns its root (0 for an empty range). A vertex weighs
// siz[x] - siz[son[x]], i.e. itself plus its light subtrees. The root is the
// first vertex at which the prefix weight reaches half of the total.
// Note: fa[0] may be written when a side is empty, matching prior behavior.
int sbuild(int l, int r) {
    int total = 0;
    for (int i = l; i <= r; i++) total += siz[stk[i]] - siz[son[stk[i]]];
    int prefix = 0;
    for (int i = l; i <= r; i++) {
        const int x = stk[i];
        prefix += siz[x] - siz[son[x]];
        if (prefix * 2 < total) continue;
        ch[x][0] = sbuild(l, i - 1);
        fa[ch[x][0]] = x;
        ch[x][1] = sbuild(i + 1, r);
        fa[ch[x][1]] = x;
        update(x);
        return x;
    }
    return 0;
}

// Builds the structure for the heavy chain that starts at u and returns the
// root of the chain's BST.
// First, every light subtree hanging off the chain is built recursively, and
// its root is linked to its attachment vertex through fa only. Then a BST is
// built over the chain itself from stk[1..top].
int build(int u) {
    for (int x = u; x; x = son[x]) {
        for (int e = hed[x]; e; e = nxt[e]) {
            const int v = to[e];
            if (v == fat[x] || v == son[x]) continue;
            fa[build(v)] = x;
        }
    }
    top = 0;
    for (int x = u; x; x = son[x]) stk[++top] = x;
    return sbuild(1, top);
}

// Point update: adds dta to vertex u's weight and refreshes every f on the
// path from u to the global root. When x roots a chain BST and has a parent
// fa[x], x's chain is a light child of fa[x]. Its contribution to g[fa[x]]
// is therefore removed before f[x] is recomputed and added back afterwards.
void update(int u, int dta) {
    for (int x = u; x; x = fa[x]) {
        const bool viaLightEdge = is_root(x) && fa[x];
        if (viaLightEdge) {
            const int p = fa[x];
            const int best = std::max(f[x][0][0], f[x][1][0]);
            g[p][0][0] -= best;
            g[p][0][1] -= best;
            g[p][1][0] -= f[x][0][0];
        }
        if (x == u) g[x][1][0] += dta;  // the weight change enters at u only
        update(x);
        if (viaLightEdge) {
            const int p = fa[x];
            const int best = std::max(f[x][0][0], f[x][1][0]);
            g[p][0][0] += best;
            g[p][0][1] += best;
            g[p][1][0] += f[x][0][0];
        }
    }
}

// Program driver, called by the global main after the namespace.
// Reads the tree, builds the global balanced binary tree, then answers each
// "u w" query (a[u] = w) with the maximum-weight independent set of the
// whole tree, read from the root rt of the top-level BST.
void main() {
    read(n), read(m);
    for (int i = 1; i <= n; i++) read(a[i]);
    for (int i = 1; i < n; i++) {
        read(u), read(v);
        nxt[tot] = hed[u], to[tot] = v, hed[u] = tot++;
        nxt[tot] = hed[v], to[tot] = u, hed[v] = tot++;
    }
    // Sentinel for "no child" (index 0): the identity of the (max, +)
    // product, [[0, -inf], [-inf, 0]]. The ordinary identity {{1,0},{0,1}}
    // is not neutral under (max, +) and would corrupt any product it entered.
    // update(int) currently skips absent children, but the sentinel must
    // still be correct.
    f[0] = {{{0, -inf}, {-inf, 0}}};
    for (int i = 1; i <= n; i++) g[i][1][1] = -inf;
    dfs(1), dfs2(1), rt = build(1);
    for (int i = 1; i <= m; i++) {
        read(u), read(w), dta = w - a[u], a[u] = w;
        update(u, dta), print(std::max(f[rt][0][0], f[rt][1][0]), '\n');
    }
}

} signed main() { return ringo::main(), 0; }