首先我们把原来的距离数组 $p$ 差分为数组 $a$。原题可以等同为在 $a$ 数组中选择 $k$ 个不相邻的数使得总和最小。

假设我们已经选择了 $a_i$ ,那么 $a_{i-1}$ 和 $a_{i+1}$ 要么同时选择,要么同时没有被选择。同时,如果我们同时选择,需要的花费即 $V_{a_{i+1}} + V_{a_{i-1}} - V_{a_i}$ 。我们维护一个堆和双向链表,每次从小根堆选择堆顶,把 $a_i$、 $a_{i-1}$ 和 $a_{i+1}$ 同时删除,再新建一个价值为 $V_{a_{i+1}} + V_{a_{i-1}} - V_{a_i}$ 的节点,扔到堆里,重复 $k$ 次就能得到答案。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// ==============================
// author: memset0
// website: https://memset0.cn
// ==============================
#include <bits/stdc++.h>
#define ll long long
#define rep(i,l,r) for (int i = l; i <= r; i++)
#define getc(x) getchar(x)
#define putc(x) putchar(x)

template <typename T> inline void read(T &x) {
x = 0; register char ch; register bool fl = 0;
while (ch = getc(), ch < 48 || 57 < ch) fl ^= ch == '-'; x = (ch & 15);
while (ch = getc(), 47 < ch && ch < 58) x = (x << 1) + (x << 3) + (ch & 15);
if (fl) x = -x;
}
template <typename T> inline void readc(T &x) {
while (x = getc(), !islower(x) && !isupper(x));
}
template <typename T> inline void print(T x, char c = ' ') {
static int buf[40];
if (x == 0) { putc('0'); putc(c); return; }
if (x < 0) putc('-'), x = -x;
for (buf[0] = 0; x; x /= 10) buf[++buf[0]] = x % 10 + 48;
while (buf[0]) putc((char) buf[buf[0]--]);
putc(c);
}

const int maxn = 1000010;
int n, m, pos, l[maxn], r[maxn], tmp[maxn];
ll ans, val[maxn];
bool vis[maxn];

struct node {
int id;
ll val;
} u, v;
bool operator < (const node &a, const node &b) {
return a.val > b.val;
}
std::priority_queue < node > q;

int main() {
read(n), read(m), --n;
for (int i = 1; i <= n + 1; i++)
read(tmp[i]);
for (int i = 1; i <= n; i++)
val[i] = tmp[i + 1] - tmp[i];
for (int i = 1; i <= n; i++)
l[i] = i - 1, r[i] = i + 1;
pos = n + 1;
for (int i = 1; i <= n; i++)
q.push(node{i, val[i]});
val[0] = val[n + 1] = 1e9;
for (int i = 1; i <= m; i++) {
while (vis[q.top().id] && q.size()) q.pop();
if (!q.size()) break;
u = q.top(), q.pop();
vis[u.id] = vis[l[u.id]] = vis[r[u.id]] = 1;
ans += u.val, v.id = ++pos;
l[v.id] = l[l[u.id]], r[v.id] = r[r[u.id]];
r[l[v.id]] = v.id, l[r[v.id]] = v.id;
val[v.id] = v.val = val[l[u.id]] + val[r[u.id]] - val[u.id];
q.push(v);
}
print(ans, '\n');
return 0;
}