$$\sum_{i=1}^{n} \sum_{j=1}^{m} (n \% i) \times (m \% j) \ \ \ (i \not= j)$$

$$= (\sum_{i=1}^{n} n \% i) \times (\sum_{j=1}^{m} m \% j) - \sum_{i=1}^{\min(n, m)} (n \% i) \times (m \% j)$$

可以转化成 $A \times B - C$ 的形式,分别来求。

其中求 $A$ 和 $B$ 的方式是一样的,可以数论分块,也可以直接打表找规律。鉴于这部分不难,笔者写了后者,而描述起来而笔者又非常懒因此忽略求 $A$ 、 $B$ 直接讲求 $C$ 。当然你也可以通过推 $C$ 的方式自己推一下 $A$ 、 $B$ 。

$$C = \sum_{i=1}^{\min(n, m)} (n \% i) \times (m \% i)$$

$$= \sum_{i=1} ^ {\min(n, m)} (n - i \times \lfloor \frac {n} {i} \rfloor) \times (m - i \times \lfloor \frac {m} {i} \rfloor )$$

$$= \sum_{i=1} ^ {\min(n, m)} n \times m - m \times i \times \lfloor \frac {n} {i} \rfloor - n \times i \times \lfloor \frac {m} {i} \rfloor + i ^ 2 \times \lfloor \frac {m} {i} \rfloor \lfloor \frac {m} {i} \rfloor $$

数论分块即可。

鉴于笔者写代码时思路非常混乱,下面仅供对拍:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// ==============================
// author: memset0
// website: https://memset0.cn
// ==============================
#include <bits/stdc++.h>
#define ll long long
#define rep(i,l,r) for (int i = l; i <= r; i++)
#define getc(x) getchar(x)
#define putc(x) putchar(x)

template <typename T> inline void read(T &x) {
x = 0; register char ch; register bool fl = 0;
while (ch = getc(), ch < 48 || 57 < ch) fl ^= ch == '-'; x = (ch & 15);
while (ch = getc(), 47 < ch && ch < 58) x = (x << 1) + (x << 3) + (ch & 15);
if (fl) x = -x;
}
template <typename T> inline void print(T x, char c = '\n') {
static int buf[40];
if (x == 0) { putc('0'); putc(c); return; }
if (x < 0) putc('-'), x = -x;
for (buf[0] = 0; x; x /= 10) buf[++buf[0]] = x % 10 + 48;
while (buf[0]) putc((char) buf[buf[0]--]);
putc(c);
}

const ll p = 19940417;

ll n, m;

void update(ll &x) {
x = (x % p + p) % p;
}

ll sum(ll n) {
ll a = n, b = n - 1;
if (a % 2 == 0) a /= 2;
else b /= 2;
a %= p, b %= p;
return a * b % p;
}

ll sum(ll l, ll r) {
ll a = l + r, b = r - l + 1;
if (a % 2 == 0) a /= 2;
else b /= 2;
a %= p, b %= p;
return a * b % p;
}

ll sum2(ll n) {
ll a = n, b = n + 1, c = (n << 1) + 1;
if (a % 2 == 0) a /= 2;
else if (b % 2 == 0) b /= 2;
else c /= 2;
if (a % 3 == 0) a /= 3;
else if (b % 3 == 0) b /= 3;
else c /= 3;
a %= p, b %= p, c %= p;
return a * b % p * c % p;
}

ll solve(ll n) {
ll m, ans = 0, sqn = sqrt(n), x, i, t;
for (i = 2, x = n; x > sqn; i++) {
t = m = (n - (n / i)) - (n - x);
m %= p;
ans += (n % x) % p * (m % p) % p, update(ans);
ans += sum(m) % p * ((i - 1) % p) % p, update(ans);
x -= t;
}
for (i = 1; i <= x; i++)
ans += (n % i) % p, update(ans);
return ans;
}

ll solve2(ll n, ll m) {
if (n > m) std::swap(n, m);
ll ans = n * n % p * m % p;
for (ll l = 1, r; l <= n; l = r + 1) {
r = std::min(n / (n / l), m / (m / l));
ans -= n % p * (m / l % p) % p * sum(l, r) % p, update(ans);
ans -= m % p * (n / l % p) % p * sum(l, r) % p, update(ans);
ans += (n / l % p) * (m / l % p) % p * ((sum2(r) - sum2(l - 1) + p) % p) % p, update(ans);
}
return ans;
}

int main() {
read(n), read(m);
print((solve(n) * solve(m) % p - solve2(n, m) + p) % p);
return 0;
}