当我跨过沉沦的一切
向着永恒开战的时候
你是我的军旗

一个非常套路的 O(n^3) DP 是用 f_{i, j} 表示填好了前 i 项、且第 i 项是前 i 项中第 j 小的方案数(只关心相对大小),转移即

f_{i, j} = \begin{cases} \displaystyle \sum_{k=1}^{j-1} f_{i-1, k} & (s_{i - 1} = \textrm{<}) \\ \displaystyle \sum_{k=j}^{i-1} f_{i-1, k} & (s_{i - 1} = \textrm{>}) \\ \end{cases}

这个 DP 用前缀和可以优化到 O(n^2)(见下面第一份代码),但很难再进一步优化,需要考虑别的算法。

考虑固定所有的 \textrm{>} ,并用容斥的方式来处理 \textrm{<} :把每个 \textrm{<} 看作「无限制」减去「\textrm{>}」。这样就只剩下 \textrm{>} 和 \neq (即无限制)两种限制。可以看做排列被 \neq 分成了若干段,每段内部的大小关系都是确定的(单调递减),假设每段的大小分别为 a_{1 \ldots m} ,那么贡献即 \displaystyle \frac {n!} {\prod_{i=1}^m a_i!} 。另外容斥系数只与硬点了几个 \textrm{>} 有关(每把一个 \textrm{<} 换成 \textrm{>} 就乘上 -1 ),可以和组合数放在一起 DP 。具体为

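dp_i = \begin{cases} 1 & (i = 0) \\ \displaystyle \sum_{j=0}^{i-1} [j = 0 \textrm{ or } s_j = \textrm{<}] (-1)^{cnt_{i-1} - cnt_j} \frac {dp_{j}} {(i - j)!} & (i > 0) \\ \end{cases}

其中 cnt_i 表示 s_1 \ldots s_i 中 \textrm{<} 的个数,最终答案为 n! \cdot dp_n 。(下面的代码实现的是对称的版本:固定 \textrm{<} 、容斥 \textrm{>} ,此时 cnt_i 统计的是 \textrm{>} 的个数。)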

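注意到这个式子非常类似于分治 NTT 的形式:令 g_j = [j = 0 \textrm{ or } s_j = \textrm{<}] (-1)^{cnt_j} dp_j ,则 \displaystyle dp_i = (-1)^{cnt_{i-1}} \sum_{j=0}^{i-1} \frac {g_j} {(i-j)!} ,这是一个在线卷积,可以用分治 NTT 在 O\left(n \log^2 n\right) 的时间复杂度内解决。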

代码

前缀和优化 DP ,O(n^2)
// =================================
//   author: memset0
//   date: 2019.07.08 14:48:24
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define ll long long
#define rep(i, l, r) for (int (i) = (l), __lim = (r); (i) <= __lim; (i)++)
#define for_each(i, a) for (size_t i = 0, __lim = a.size(); i < __lim; ++i)
namespace ringo {

// Parses a (possibly negative) decimal integer from stdin into x, skipping
// any non-digit characters before it. Not used by this solution.
template <class T> inline void read(T &x) {
  bool negative = false;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar()) negative ^= (ch == '-');
  for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
  if (negative) x = -x;
}

// Writes x to stdout in decimal, with a leading '-' for negative values.
template <class T> inline void print(T x) {
  if (x < 0) {
    putchar('-');
    x = -x;
  }
  char digits[48];
  int top = 0;
  do {
    digits[top++] = static_cast<char>('0' + x % 10);
    x /= 10;
  } while (x > 0);
  while (top > 0) putchar(digits[--top]);
}

// Writes x and then the single character c.
template <class T> inline void print(T x, char c) {
  print(x);
  putchar(c);
}

// Debug helper: writes a[l..r] space-separated on one line, optionally
// preceded by "s: ".
template <class T> inline void print(T a, int l, int r, std::string s = "") {
  if (!s.empty()) std::cout << s << ": ";
  for (int i = l; i <= r; i++) print(a[i], i == r ? '\n' : ' ');
}

const int N = 2e3 + 10, mod = 998244353;
char s[N];     // pattern, stored 1-based; s[i] relates positions i and i + 1
int n, m, ans; // n = permutation length (|s| + 1)
// f[i][j]: ways to fill the first i positions so that position i holds the
//          j-th smallest value among them.
// g[i][.]: running sums of f[i][.] for the step into i + 1 — prefix sums
//          when s[i] == '<', suffix sums otherwise.
int f[N][N], g[N][N];

// Modular subtraction / addition for operands already reduced into [0, mod).
inline int dec(int a, int b) { return a < b ? a - b + mod : a - b; }
inline int inc(int a, int b) { return a + b >= mod ? a + b - mod : a + b; }
// Modular product, computed in 64-bit arithmetic.
inline int mul(int a, int b) { return static_cast<int>(1LL * a * b % mod); }

// Reads the pattern and prints the number of permutations of length
// |s| + 1 that follow it, modulo 998244353, using the O(n^2) prefix-sum DP.
void main() {
  scanf("%s", s + 1);
  n = static_cast<int>(strlen(s + 1)) + 1;
  f[1][1] = g[1][1] = 1;
  for (int i = 2; i <= n; i++) {
    const bool ascending = (s[i - 1] == '<');
    for (int j = 1; j <= i; j++) {
      // '<': the previous rank lies in [1, j - 1] (a prefix sum);
      // '>': the previous rank lies in [j, i - 1] (a suffix sum).
      f[i][j] = ascending ? g[i - 1][j - 1] : g[i - 1][j];
    }
    int running = 0;
    if (s[i] == '<') {
      for (int j = 1; j <= i; j++) g[i][j] = running = inc(running, f[i][j]);
    } else {
      for (int j = i; j >= 1; j--) g[i][j] = running = inc(running, f[i][j]);
    }
  }
  for (int j = 1; j <= n; j++) ans = inc(ans, f[n][j]);
  print(ans, '\n');
}

}  // namespace ringo

int main() {
#ifdef memset0
  freopen("1.in", "r", stdin);
#endif
  ringo::main();
  return 0;
}
容斥 DP ,O(n^2)
#include <bits/stdc++.h>
#define ll long long
#define rep(i, l, r) for (int (i) = (l), __lim = (r); (i) <= __lim; (i)++)
#define for_each(i, a) for (size_t i = 0, __lim = a.size(); i < __lim; ++i)
namespace ringo {

// Parses a (possibly negative) decimal integer from stdin into x, skipping
// any non-digit characters before it. Not used by this solution.
template <class T> inline void read(T &x) {
  bool negative = false;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar()) negative ^= (ch == '-');
  for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
  if (negative) x = -x;
}

// Writes x to stdout in decimal, with a leading '-' for negative values.
template <class T> inline void print(T x) {
  if (x < 0) {
    putchar('-');
    x = -x;
  }
  char digits[48];
  int top = 0;
  do {
    digits[top++] = static_cast<char>('0' + x % 10);
    x /= 10;
  } while (x > 0);
  while (top > 0) putchar(digits[--top]);
}

// Writes x and then the single character c.
template <class T> inline void print(T x, char c) {
  print(x);
  putchar(c);
}

// Debug helper: writes a[l..r] space-separated on one line, optionally
// preceded by "s: ".
template <class T> inline void print(T a, int l, int r, std::string s = "") {
  if (!s.empty()) std::cout << s << ": ";
  for (int i = l; i <= r; i++) print(a[i], i == r ? '\n' : ' ');
}

const int N = 2e3 + 10, mod = 998244353;
char s[N];  // pattern, stored 1-based
// cnt[i]: number of '>' among s[1..i]; fac / inv_fac: factorials and their
// modular inverses; dp: the inclusion-exclusion DP (see main).
int n, dp[N], cnt[N], fac[N], inv_fac[N];

// Modular subtraction / addition for operands already reduced into [0, mod).
inline int dec(int a, int b) { return a < b ? a - b + mod : a - b; }
inline int inc(int a, int b) { return a + b >= mod ? a + b - mod : a + b; }
// Modular product, computed in 64-bit arithmetic.
inline int mul(int a, int b) { return static_cast<int>(1LL * a * b % mod); }

// Inclusion-exclusion DP. Every '<' is kept, and every '>' is expanded as
// "no constraint" minus "<". Cutting at the '>' positions taken as "no
// constraint" leaves increasing blocks of sizes a_1..a_m, which contribute
// n! / (a_1! ... a_m!). Each '>' turned into '<' contributes a factor -1:
//   dp[i] = sum over j < i with (j == 0 or s[j] == '>') of
//           dp[j] * (-1)^(cnt[i-1] - cnt[j]) / (i - j)!
// The answer printed is n! * dp[n].
void main() {
  scanf("%s", s + 1);
  n = static_cast<int>(strlen(s + 1)) + 1;
  for (int i = 1; i <= n; i++) cnt[i] = cnt[i - 1] + (s[i] == '>' ? 1 : 0);

  // inv_fac first receives the plain inverses of 2..n via the linear
  // recurrence (mod % i < i is already filled), then is turned into prefix
  // products.
  fac[0] = fac[1] = inv_fac[0] = inv_fac[1] = 1;
  for (int i = 2; i <= n; i++) {
    fac[i] = mul(fac[i - 1], i);
    inv_fac[i] = mul(mod - mod / i, inv_fac[mod % i]);
  }
  for (int i = 2; i <= n; i++) inv_fac[i] = mul(inv_fac[i - 1], inv_fac[i]);

  dp[0] = 1;
  for (int i = 1; i <= n; i++) {
    int total = 0;
    for (int j = 0; j < i; j++) {
      if (j > 0 && s[j] != '>') continue;  // only '>' positions can be cut points
      const bool negative = ((cnt[i - 1] + cnt[j]) & 1) != 0;
      total = inc(total, mul(dp[j], mul(negative ? mod - 1 : 1, inv_fac[i - j])));
    }
    dp[i] = total;
  }
  for (int i = 0; i <= n; i++) dp[i] = mul(dp[i], fac[i]);
  print(dp[n], '\n');
}

}  // namespace ringo

int main() {
#ifdef memset0
  freopen("1.in", "r", stdin);
#endif
  ringo::main();
  return 0;
}
分治 NTT 优化容斥 DP ,O(n \log^2 n)
// =================================
//   author: memset0
//   date: 2019.07.08 15:23:33
//   website: https://memset0.cn/
// =================================
#include <bits/stdc++.h>
#define ll long long
#define rep(i, l, r) for (int (i) = (l), __lim = (r); (i) <= __lim; (i)++)
#define for_each(i, a) for (size_t i = 0, __lim = a.size(); i < __lim; ++i)
namespace ringo {

// Parses a (possibly negative) decimal integer from stdin into x, skipping
// any non-digit characters before it. Not used by this solution.
template <class T> inline void read(T &x) {
  x = 0; char c = getchar(); bool f = 0;
  while (!isdigit(c)) f ^= c == '-', c = getchar();
  while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
  if (f) x = -x;
}
// Writes x to stdout in decimal, with a leading '-' for negative values.
template <class T> inline void print(T x) {
  if (x < 0) putchar('-'), x = -x;
  if (x > 9) print(x / 10);
  putchar('0' + x % 10);
}
// Writes x and then the single character c.
template <class T> inline void print(T x, char c) { print(x), putchar(c); }
// Debug helper: writes a[l..r] space-separated on one line, optionally
// preceded by "s: ".
template <class T> inline void print(T a, int l, int r, std::string s = "") {
  if (s != "") std::cout << s << ": ";
  for (int i = l; i <= r; i++) print(a[i], " \n"[i == r]);
}

// N bounds both the pattern length and every NTT length used below.
const int N = 4e5 + 10, mod = 998244353;
char s[N];  // pattern, stored 1-based
// a, b: NTT scratch buffers; w: twiddle-factor table; rev: bit-reversal
// permutation; f, g: DP arrays driven by solve(); cnt[i]: number of '>'
// among s[1..i]; fac / inv_fac: factorials and their modular inverses.
int n, a[N], b[N], w[N], f[N], g[N], cnt[N], rev[N], fac[N], inv_fac[N];

// Modular negation / subtraction / addition for values in [0, mod).
inline int opp(int x) { return x ? mod - x : 0; }
inline int dec(int a, int b) { a -= b; return a < 0 ? a + mod : a; }
inline int inc(int a, int b) { a += b; return a >= mod ? a - mod : a; }
// Modular product, computed in 64-bit arithmetic.
inline int mul(int a, int b) { return static_cast<int>(1LL * a * b % mod); }
// Modular inverse of x in [1, mod) via inv(x) = -(mod / x) * inv(mod % x).
inline int inv(int x) { return x < 2 ? 1 : mul(mod - mod / x, inv(mod % x)); }
// a^b modulo mod by binary exponentiation.
inline int fpow(int a, int b) {
  int res = 1;
  for (; b; b >>= 1, a = mul(a, a))
    if (b & 1) res = mul(res, a);
  return res;
}

// Prepares a transform of length lim, the smallest power of two >= need.
// It fills rev[0..lim) with the bit-reversal permutation and extends the
// cached twiddle table so that w[len + j] = wn^j with
// wn = 3^((mod - 1) / (2 * len)) for every power of two len < lim.
// Returns lim.
int init(int need) {
  int lim = 1;
  while (lim < need) lim <<= 1;
  // The top bit for odd i is lim >> 1. Writing it as (i & 1) << (log2(lim) - 1)
  // would shift by -1 (undefined behaviour) when lim == 1, which happens for
  // every solve(l, l + 1).
  for (int i = 0; i < lim; i++)
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
  // Twiddle factors persist across calls; only missing levels are built.
  static int base_len = 1;
  for (int len = base_len, wn; len < lim; base_len = len <<= 1) {
    wn = fpow(3, (mod - 1) / (len << 1)), w[len] = 1;
    for (int i = 1; i < len; i++) w[i + len] = mul(w[i + len - 1], wn);
  }
  return lim;
}

// In-place iterative NTT of a[0..lim), using rev / w prepared by init(lim).
// The caller obtains the inverse transform by reversing a[1..lim),
// transforming again and scaling by 1 / lim.
void ntt(int *a, int lim) {
  for (int i = 0; i < lim; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
  for (int len = 1; len < lim; len <<= 1)
    for (int i = 0; i < lim; i += (len << 1))
      for (int j = 0; j < len; j++) {
        int x = a[i + j], y = mul(a[i + j + len], w[j + len]);
        a[i + j] = inc(x, y), a[i + j + len] = dec(x, y);
      }
}

// Divide-and-conquer online convolution over indices [l, r]. It computes
//   f[i] = (-1)^cnt[i-1] * sum_{j < i} g[j] * inv_fac[i - j],   f[0] = 1,
//   g[i] = (-1)^cnt[i] * f[i]   if i == 0 or s[i] == '>', else 0.
// Invariant: on entry, f[i] for i in [l, r] already holds every term with
// j < l. A leaf therefore only applies the sign and publishes g[l]. Between
// the halves, g[l..mid] is convolved with inv_fac[1..r-l] and added to
// f[mid+1..r].
void solve(int l, int r) {
  if (l == r) {
    if (!l) f[0] = 1;
    else f[l] = cnt[l - 1] & 1 ? opp(f[l]) : f[l];
    if (s[l] == '>' || !l) g[l] = cnt[l] & 1 ? opp(f[l]) : f[l];
    return;
  }
  int mid = (l + r) >> 1;
  solve(l, mid);
  // Coefficient k of g[l..mid] * inv_fac[1..r-l] contributes to f[l + 1 + k].
  int len1 = mid - l + 1, len2 = r - l;
  int lim = init(len1 + len2 - 1), inv_lim = inv(lim);
  for (int i = 0; i < lim; i++) a[i] = i < len1 ? g[i + l] : 0;
  for (int i = 0; i < lim; i++) b[i] = i < len2 ? inv_fac[i + 1] : 0;
  ntt(a, lim), ntt(b, lim);
  for (int i = 0; i < lim; i++) a[i] = mul(a[i], b[i]);
  std::reverse(a + 1, a + lim), ntt(a, lim);  // inverse transform
  for (int i = 0; i < lim; i++) a[i] = mul(a[i], inv_lim);
  for (int i = mid + 1; i <= r; i++) f[i] = inc(f[i], a[i - l - 1]);
  solve(mid + 1, r);
}

// Reads the pattern and prints the number of permutations of length
// |s| + 1 that follow it, modulo 998244353. The answer is n! * f[n], where
// f is the inclusion-exclusion DP evaluated by solve().
void main() {
  scanf("%s", s + 1);
  n = strlen(s + 1) + 1;
  for (int i = 1; i <= n; i++) cnt[i] = cnt[i - 1] + (s[i] == '>');
  // inv_fac first receives the plain inverses of 2..n, then is turned into
  // prefix products (inverse factorials).
  fac[0] = fac[1] = inv_fac[0] = inv_fac[1] = 1;
  for (int i = 2; i <= n; i++) fac[i] = mul(fac[i - 1], i);
  for (int i = 2; i <= n; i++) inv_fac[i] = mul(mod - mod / i, inv_fac[mod % i]);
  for (int i = 2; i <= n; i++) inv_fac[i] = mul(inv_fac[i - 1], inv_fac[i]);
  solve(0, n);
  for (int i = 0; i <= n; i++) f[i] = mul(f[i], fac[i]);
  print(f[n], '\n');
}

}  // namespace ringo

int main() {
#ifdef memset0
  freopen("1.in", "r", stdin);
#endif
  ringo::main();
  return 0;
}

分治 NTT 巧妙的思路 容斥 DP 排列 DP Favorite

仅有一条评论

  1. Jack_killer Jack_killer

    您太强了。

LOJ6627 等比数列三角形
上一篇 «
LOJ573 「LibreOJ NOI Round #2」单枪匹马
» 下一篇