联测的题,刚好前几天做过类似的 idea 结果就 AC 了 QAQ ...

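考虑对于每个 $k$,令 $b_i=[a_i=a_{i+k}]\ (1\le i\le n-k)$,对于 $b$ 中每个长度不小于 $k$ 的极长全 $1$ 连续段,把段内的每个 $i$ 与 $i+k$ 连边。每隔 $k$ 个位置设一个关键点,这样的段一定包含某个关键点,于是对每个 $k$ 只需检查 $n/k$ 个关键点,由调和级数可知总次数不超过 $O(n\log n)$。关键点向两侧能延伸的长度(最长公共前缀 / 最长公共后缀)可以使用后缀数组维护,类似「优秀的拆分」,也可以直接二分。后缀数组的复杂度是 $O(n\log n)$。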

考虑连边,可以利用「萌萌哒」一题的 idea,在倍增数组上递归合并:如果当前的两段已经在同一集合中就直接 return,否则合并后递归处理左右两半,最终在长度为 $1$ 的层上真正合并的次数计入答案。可以证明这个做法的复杂度是均摊 $\log n$ 级别的。

故如果使用后缀数组 + ST 表 + 倍增并查集,复杂度为 $O(n\log n\cdot\alpha(n))$,可以通过此题。在 $n$ 较小而 $T$ 较大的数据点上可能因为常数原因 TLE,可以考虑对于 $n$ 足够小的部分用 $O(n^2)$ 暴力做。

代码:

// =================================
//   author: memset0
//   date: 2019.03.15 08:52:07
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define ll long long
#define debug(...) ((void)0)
#ifndef debug
#define debug(...) fprintf(stderr,__VA_ARGS__)
#endif
namespace ringo {
// Reads one integer from stdin into x.
// Every non-digit character before the number is skipped; each '-' among them
// toggles the sign. If EOF is reached before any digit, x is left as 0 instead
// of looping forever.
template <class T> inline void read(T &x) {
    x = 0;
    // Keep getchar()'s result as int: a char cannot hold EOF distinctly, and
    // isdigit() requires a value representable as unsigned char or EOF.
    int c = getchar();
    bool negative = false;
    while (c != EOF && !isdigit(c)) negative ^= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + (c - '0'), c = getchar();  // isdigit(EOF) is false
    if (negative) x = -x;
}
// Writes the decimal representation of integer x to stdout (no separator).
// Digits are taken from x % 10 on the signed value itself (C++11 division
// truncates toward zero), so the most negative value of T is printed
// correctly; the previous `x = -x` overflowed for it (undefined behaviour).
template <class T> inline void print(T x) {
    char digits[64];
    int len = 0;
    const bool negative = x < 0;
    do {
        const int d = static_cast<int>(x % 10);  // in [-9, 9], same sign as x
        digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
        x /= 10;
    } while (x != 0);
    if (negative) putchar('-');
    while (len > 0) putchar(digits[--len]);
}
// Writes x followed by the single character c.
template <class T> inline void print(T x, char c) { print(x), putchar(c); }

// Array capacity. NOTE(review): presumably twice the largest n, because mem()
// clears up to twice the (power-of-two rounded) length — confirm.
const int N = 1200010;
// Number of doubling levels stored in fa. main() initialises levels 0..L-1 with
// L = log[n] + 1, so this requires n < 2^20.
const int M = 20;
int T;      // number of test cases
int W;      // weight of the shift currently being merged (read by solve/merge)
int L;      // number of doubling levels for the current n
int n;      // length of the current array
int a[N];   // input array, 1-indexed
int w[N];   // w[k]: weight of shift k, for k = 1..n/2
int b[N];   // NOTE(review): appears unused here; for_small_data declares its own b
int fa[M][N];  // fa[k][x]: union-find parent of the length-2^k block starting at x
long long ans; // answer accumulated for the current test
std::vector<std::pair<int, int> > c;  // (weight, shift) pairs; sorted in main()

namespace for_small_data {
    // Direct solver that main() uses when n <= 3000. For every shift k (in the
    // order of the sorted global c) it scans a directly and joins i with i + k
    // through a plain union-find, adding that shift's weight per successful join.
    const int N = 3e5 + 10;
    int b[N], fa[N]; ll ans;

    // NOTE(review): edge / E are only cleared in the visible code, never filled.
    struct edge {
        int u, v, w;
        inline bool operator < (const edge &rhs) const { return w < rhs.w; }
    };
    std::vector <edge> E;

    // Returns the root of x; afterwards every node on the walked path points
    // directly at that root (full path compression, done iteratively).
    int find(int x) {
        int root = x;
        while (fa[root] != root) root = fa[root];
        while (x != root) {
            const int parent = fa[x];
            fa[x] = root;
            x = parent;
        }
        return root;
    }

    // Solves one test of length n (reads the globals a and c) and prints the answer.
    void solve(int n) {
        E.clear();
        ans = 0;
        for (int v = 1; v <= n; ++v) fa[v] = v;
        for (const auto &entry : c) {
            const int weight = entry.first;
            const int k = entry.second;
            // b[1..last-1] hold the flags a[i] == a[i + k]; b[last] = 0 is a
            // sentinel that closes the final run.
            const int last = n + 1 - k;
            for (int i = 1; i < last; ++i) b[i] = a[i] == a[i + k];
            b[last] = 0;
            int run = 0;  // length of the current run of ones in b
            for (int i = 1; i <= last; ++i) {
                if (b[i]) {
                    ++run;
                    continue;
                }
                // A run of length >= k means a[i-run .. i-1+k] has period k and
                // spans at least two periods: join every position with its partner.
                if (run >= k) {
                    for (int u = i - run; u < i; ++u) {
                        const int ru = find(u), rv = find(u + k);
                        if (ru != rv) {
                            fa[ru] = rv;
                            ans += weight;
                        }
                    }
                }
                run = 0;
            }
        }
        print(ans, '\n');
    }
}

// Returns the root of x in the level-k union-find (fa[k]) and points every node
// on the walked path directly at that root. Iterative on purpose: unions here
// are not balanced, so parent chains can grow to O(n) length, and the previous
// recursive version could overflow the call stack on such a chain.
int find(int x, int k) {
    int *const parent = fa[k];
    int root = x;
    while (parent[root] != root) root = parent[root];
    // Second pass: full path compression (same final state as the recursive form).
    while (x != root) {
        const int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}

// Joins, at doubling level k, the blocks starting at u and v (each 2^k long),
// then recurses into their left halves and right halves down to level 0.
// If the two blocks are already in one set at level k, nothing is joined and
// W * 2^k is subtracted from ans (the caller pre-adds W for every pair it covers).
void solve(int u, int v, int k) {
    const int root_u = find(u, k);
    const int root_v = find(v, k);
    if (root_u == root_v) {
        ans -= (ll)W * (1 << k);
        return;
    }
    fa[k][root_u] = root_v;
    if (k > 0) {
        const int half = 1 << (k - 1);
        solve(u, v, k - 1);
        solve(u + half, v + half, k - 1);
    }
}

// Joins position a + j with c + j for every j in [0, b - a] using the doubling
// union-find, adding W to ans for each join that merges two components.
// d is expected to be c + (b - a); it is not needed to compute anything.
//
// The range is covered by two (possibly overlapping) blocks of length
// 2^t, t = floor(log2(b - a + 1)), so only two top-level solve() calls are made
// instead of one per set bit of the length (removing a log factor). Each call
// pre-adds W * 2^t; pairs in the overlap are already joined by the first call,
// so the second call's solve() finds them connected and subtracts them again.
void merge(int a, int b, int c, int d) {
    // printf("merge [%d %d] => [%d %d]\n", a, b, c, d);
    (void)d;
    const int len = b - a + 1;
    if (len <= 0) return;
    int t = 0;
    while ((2 << t) <= len) ++t;  // largest t with 2^t <= len
    const int block = 1 << t;
    ans += (ll)W * block;
    solve(a, c, t);
    if (block != len) {
        const int shift = len - block;  // start offset of the second, right-aligned block
        ans += (ll)W * block;
        solve(a + shift, c + shift, t);
    }
}

// Zeroes the first min(N, 2 * l) elements of a, where l is the smallest power of
// two >= n (l = 1 when n <= 1). NOTE(review): the doubled range presumably keeps
// reads such as tmp[sa[i] + now] in SA::build inside cleared memory — confirm.
void mem(int *a, int n) {
    int l = 1;
    while (l < n) l <<= 1;
    const int count = std::min(N, l << 1);
    // Byte count from sizeof rather than the former "<< 2", which assumed 4-byte int.
    memset(a, 0, sizeof(*a) * count);
}

// Scratch arrays shared by both SA instances, and a floor(log2) table that main()
// fills once (log[0] = -1, log[i] = log[i >> 1] + 1).
int log[N], tmp[N], tax[N];

// Suffix array of s[1..len] built by prefix doubling with counting sort, plus the
// height array (LCP of rank-adjacent suffixes) and a sparse table over it.
// NOTE(review): the first counting sort uses the global n as bucket count and never
// resets tax[0], so s[i] is presumably required to lie in [1, n] — confirm.
// NOTE(review): st has 20 columns while build() fills columns 1..L-1 with
// L = log[n] + 1 (set in main), so n must stay below 2^20.
struct SA {
    int siz, len;  // siz: number of distinct keys for the next counting sort; len: string length
    int sa[N], rnk[N], height[N], st[N][20];

    // LCP of the suffixes whose ranks are l and r (l != r): the minimum of height
    // over ranks (min(l, r), max(l, r)], answered from the sparse table.
    int min(int l, int r) {
        if (l > r) std::swap(l, r);
        l++;
        int t = log[r - l + 1];
        return std::min(st[l][t], st[r - (1 << t) + 1][t]);
    }

    // Stable counting sort: orders the positions tmp[1..len] by key rnk[] (keys in
    // [1, siz]) and stores the result in sa[1..len].
    void sort() {
        for (int i = 1; i <= siz; i++) tax[i] = 0;
        for (int i = 1; i <= len; i++) tax[rnk[i]]++;
        for (int i = 1; i <= siz; i++) tax[i] += tax[i - 1];
        for (int i = len; i >= 1; i--) sa[tax[rnk[tmp[i]]]--] = tmp[i];
    }

    void build(int *s, int _len) {
        len = _len, mem(sa, len), mem(rnk, len), mem(tmp, len), mem(height, len), siz = n;
        for (int i = 1; i <= len; i++) rnk[i] = s[i], tmp[i] = i;
        sort();
        // Each round orders suffixes by (rank of first `now` items, rank of the next
        // `now` items); stop once all len ranks are distinct.
        for (int now = 1, cnt = 0; cnt < len; siz = cnt, now <<= 1) {
            cnt = 0;
            // Suffixes with an empty second half come first in second-key order...
            for (int i = len; i >= len - now + 1; i--) tmp[++cnt] = i;
            // ...followed by the rest, in the order of their second halves.
            for (int i = 1; i <= len; i++) if (sa[i] > now) tmp[++cnt] = sa[i] - now;
            // rnk and tmp are only written at indices 1..len, and every other index this
            // build reads (tmp[sa[i] + now] < 2 * len) was zeroed by mem() above, so
            // swapping indices 0..len is equivalent to the former whole-array
            // std::swap(rnk, tmp), which touched all N elements every round regardless of len.
            sort(), std::swap_ranges(rnk, rnk + len + 1, tmp), rnk[sa[1]] = cnt = 1;
            for (int i = 2; i <= len; i++)
                rnk[sa[i]] = (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + now] == tmp[sa[i - 1] + now]) ? cnt : ++cnt;
        }
        // height[rnk[i]] = LCP(suffix i, suffix ranked just before it). The match
        // length drops by at most one from i to i + 1, so h is only decremented once.
        // For rnk[i] == 1, j = sa[0] == 0 (cleared); height[1] is reset to 0 below.
        for (int i = 1, j, h = 0; i <= len; i++) {
            j = sa[rnk[i] - 1]; if (h) --h;
            while (s[i + h] == s[j + h]) ++h;
            height[rnk[i]] = h;
        } height[1] = 0;
        // Sparse table: st[j][i] = min(height[j .. j + 2^i - 1]).
        for (int i = 1; i <= len; i++) st[i][0] = height[i];
        for (int i = 1; i < L; i++)
            for (int j = 1; j + (1 << i) - 1 <= len; j++)
                st[j][i] = std::min(st[j][i - 1], st[j + (1 << (i - 1))][i - 1]);
    }
} f[2];  // f[0] is built on a, f[1] on a reversed (see main)

// Program body (invoked from ::main). For each of T test cases: reads n, a[1..n]
// and w[1..n/2], processes the shifts k in increasing order of w[k] (Kruskal
// order) joining positions i and i + k inside every long enough k-periodic run,
// and prints the accumulated cost ans.
void main() {
    // freopen("endless.in", "r", stdin), freopen("endless.out", "w", stdout);
    // floor(log2(i)) table; log[0] = -1.
    log[0] = -1; for (int i = 1; i < N; i++) log[i] = log[i >> 1] + 1;
    for (read(T); T--; ) {
        // L = number of doubling levels needed for length n; a is cleared before reading.
        read(n), L = log[n] + 1, ans = 0, c.clear(), mem(a, n);
        for (int i = 1; i <= n; i++) read(a[i]);
        // c collects (weight, shift) pairs; sorting orders them by weight, then shift.
        for (int i = 1; i <= (n >> 1); i++) read(w[i]), c.push_back(std::make_pair(w[i], i));
        std::sort(c.begin(), c.end());
        // Small inputs use the direct scan. NOTE(review): presumably to avoid the
        // per-test setup cost of the suffix-array path — confirm.
        if (n <= 3000) { for_small_data::solve(n); continue; }
        // f[0]: suffix array of a; f[1]: suffix array of reversed a (for common
        // suffixes). a is reversed back immediately so later code sees it unchanged.
        f[0].build(a, n), std::reverse(a + 1, a + n + 1);
        f[1].build(a, n), std::reverse(a + 1, a + n + 1);
        // Every block is its own set at every doubling level.
        for (int k = 0; k < L; k++) for (int i = 1; i <= n; i++) fa[k][i] = i;
        for (auto &it : c) {
            int w = it.first, k = it.second; W = w;
            // Checkpoints l = k, 2k, ... with r = l + 1. lcs is the longest common
            // suffix of a[1..l] and a[1..l+k]; lcp the longest common prefix of a[r..]
            // and a[r+k..]. Together they give the maximal run of positions i around the
            // checkpoint with a[i] == a[i + k], namely [l - lcs + 1, r + lcp - 1].
            for (int l = k, r = k + 1, lcp, lcs; r <= n; l += k, r += k) {
                lcp = r + k <= n ? f[0].min(f[0].rnk[r], f[0].rnk[r + k]) : 0;
                lcs = l + k <= n ? f[1].min(f[1].rnk[n - l + 1], f[1].rnk[n - l - k + 1]) : 0;
                // printf("k=%d l=%d r=%d lcs=%d lcp=%d\n", k, l, r, lcs, lcp);
                // Only runs of length >= k (two full periods) produce joins.
                if (lcp + lcs < k) continue;
                // Join every i in the run with i + k.
                merge(l - lcs + 1, r + lcp - 1, l - lcs + 1 + k, r + lcp - 1 + k);
            }
        } print(ans, '\n');
    }
}

} signed main() { return ringo::main(), 0; }

标签: none

已有 2 条评论

  1. sunset sunset

    merge 这个函数你写成两个 log 了……可以取 $k=\lfloor\log_2(b-a+1)\rfloor$,然后 unite(k, a, c)、unite(k, b-2^k+1, d-2^k+1),这样就只有一个 log 了 = =

添加新评论