Tree Swapping (Hard) This is the hard version of t...

Created on: June 24, 2026

Answered with GPT-5.5 Thinking by Chat01

Question

Tree Swapping (Hard)
This is the hard version of the problem. Here, the sum of $N$ is $\le 2 \cdot 10^5$.

You are given a tree with $N$ nodes labelled $1$ through $N$.

For a given pair of binary strings $S$ and $T$, we define $f(S,T)$ as follows:

In one operation, you can swap $S_u$ and $S_v$ if $(u,v)$ are adjacent in the tree.
$f(S,T)$ is defined as the minimum swaps necessary to make $S = T$.
In case it is not possible to make $S = T$, $f(S,T)$ would be defined as $0$.
You are given a string $S$ where some characters are `?` instead, indicating that they are not fixed yet.

Over all possible binary strings $S$ (formed by replacing `?` by $0$ or $1$), find the sum of $f(S,T)$ modulo $998244353$.

Input Format
The first line of input will contain a single integer $T$, denoting the number of test cases.
Each test case consists of multiple lines of input.
The first line contains a single integer $N$.
The next $N-1$ lines contain $2$ integers each - $u$ and $v$, representing an edge.
The next line contains a string $S$.
The next line contains a binary string $T$.
Output Format
For each test case, output the sum of $f(S,T)$ over all valid $S$ modulo $998244353$.

Constraints
$1 \le T \le 10^4$
$2 \le N \le 2 \cdot 10^5$
$1 \le u, v \le N$
$S_i \in \{0, 1, ?\}$
$T_i \in \{0, 1\}$
The sum of $N$ over all test cases does not exceed $2 \cdot 10^5$.
Sample 1:
Input
3
2
1 2
??
01
3
1 2
2 3
0??
111
5
1 2
1 3
1 4
1 5
0??1?
10101
Output
1
0
7
Explanation:
Test Case 1: There are $4$ possible strings:

$S = 00, 11$: $f(S,T) = 0$ because it is impossible to make $S = T$.
$S = 01$: $f(S,T) = 0$ because $S = T$.
$S = 10$: $f(S,T) = 1$ because you can do one swap of $S_1$ and $S_2$.
This is the problem. Give me the solution in C++.

Python

python
from math import comb

MOD = 998244353


def H_direct(n, r, b, m):
    s = 0
    for x in range(0, b + 1):
        if 0 <= r - x <= n - b:
            s += comb(b, x) * comb(n - b, r - x) * abs(x - m)
    return s


def H_formula(n, r, b, m):
    total = comb(n, r)
    tx = b * comb(n - 1, r - 1) if r > 0 else 0
    U = UX = 0
    for x in range(0, b + 1):
        if x > m and 0 <= r - x <= n - b:
            w = comb(b, x) * comb(n - b, r - x)
            U += w
            UX += x * w
    return m * total - tx + 2 * (UX - m * U)


for n in range(1, 10):
    for r in range(n + 1):
        for b in range(n + 1):
            for m in range(-2, b + 3):
                d = H_direct(n, r, b, m)
                f = H_formula(n, r, b, m)
                if d != f:
                    print("bad", n, r, b, m, d, f)
                    raise SystemExit
print("ok")
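The identity that the script above checks numerically can be written out as follows. This is a reading of `H_direct` and `H_formula` in the script's own symbols ($n$, $r$, $b$, $m$), not an independent derivation from the problem author; the two closed-form sums used are the standard Vandermonde identities.

$$
\begin{aligned}
H(n,r,b,m) &= \sum_{x}\binom{b}{x}\binom{n-b}{r-x}\,|x-m| \\
&= m\sum_{x} w_x \;-\; \sum_{x} x\,w_x \;+\; 2\sum_{x>m}(x-m)\,w_x,
\qquad w_x=\binom{b}{x}\binom{n-b}{r-x}, \\
&= m\binom{n}{r} \;-\; b\binom{n-1}{r-1} \;+\; 2\,(UX - m\,U),
\end{aligned}
$$

where $U=\sum_{x>m} w_x$ and $UX=\sum_{x>m} x\,w_x$. The second line follows from $|x-m| = (m-x) + 2\max(x-m,0)$, so only the upper-tail sums $U$ and $UX$ depend on $m$ in a non-trivial way.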

Python

python
from math import comb, factorial


def transform_s(n, r, b, m):
    # raw rational with integers numerator? multiply by facs
    L = b - m
    # S_less y<L
    S = SX = 0
    for y in range(0, min(L, b) + 1):  # y<L actually range(L), but y<=b
        if y >= L:
            break
        z = b - y
        if 0 <= z <= r and 0 <= y <= n - r:
            raw = 1 / (factorial(y) * factorial(n - r - y) * factorial(z) * factorial(r - z))
            S += raw
            SX += z * raw
    factor = factorial(b) * factorial(n - b)
    return S * factor, SX * factor


def upper_count(n, r, b, m):
    U = UX = 0
    for x in range(0, b + 1):
        if x > m and 0 <= r - x <= n - b:
            w = comb(b, x) * comb(n - b, r - x)
            U += w
            UX += x * w
    return U, UX


for n in range(1, 10):
    for r in range(n + 1):
        for b in range(n + 1):
            for m

Thought

Final approach and C++ implementation

I need to provide a final solution with explanation and C++ code, making sure to mention the complexity. I'll avoid making it too complex, and I'll include the "hard version" with NTT block decomposition explained in an understandable way. The code will ensure we handle MOD correctly when necessary, subtracting only once to keep the sum below MOD.

Thought for 6m 40s

Answer

Below is a full C++17 solution for the hard version.

Key ideas:

  1. For a fixed completed S, the minimum number of adjacent swaps on a tree equals $\sum_{\text{edges } e} \left|\#1_S(\text{subtree}) - \#1_T(\text{subtree})\right|$, where the subtree is the part of the tree on one side of $e$.
  2. Only completions with the same total number of 1s as T are possible.
  3. For every edge, its contribution depends only on:
    • b = number of ? on one side of the edge,
    • m = number of those ? that must become 1 for that side to balance (the 1s of T on that side minus the fixed 1s of S on that side).
  4. We need many hypergeometric prefix sums, so the code uses sqrt/block decomposition + NTT modulo 998244353.
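Ideas 1 to 3 can be sketched without the NTT step, as a slow reference that evaluates each edge's hypergeometric sum directly. This is a minimal sketch, not the optimized method: the name `solveSlow` and the Pascal-triangle table are assumptions made for illustration, and the quadratic memory makes it suitable only for small inputs.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

// Hypothetical slow reference (not part of the original answer).
// For every edge it sums C(b, x) * C(Q - b, R - x) * |x - m| over x directly.
std::int64_t solveSlow(int n, const std::vector<std::pair<int, int>>& edges,
                       const std::string& S, const std::string& T) {
    const std::int64_t MOD = 998244353;
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // BFS order from node 1, so every child appears after its parent.
    std::vector<int> parent(n + 1, 0), order;
    parent[1] = -1;
    order.push_back(1);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int v = order[i];
        for (int to : adj[v]) {
            if (to == parent[v]) continue;
            parent[to] = v;
            order.push_back(to);
        }
    }
    // Subtree counts: '?' characters, fixed 1s of S, 1s of T.
    std::vector<int> q(n + 1), one(n + 1), tgt(n + 1);
    for (int i = 1; i <= n; ++i) {
        q[i] = (S[i - 1] == '?');
        one[i] = (S[i - 1] == '1');
        tgt[i] = (T[i - 1] == '1');
    }
    for (int i = n - 1; i > 0; --i) {
        int v = order[i], p = parent[v];
        q[p] += q[v];
        one[p] += one[v];
        tgt[p] += tgt[v];
    }
    int Q = q[1];
    int R = tgt[1] - one[1];  // how many '?' must become 1 overall
    if (R < 0 || R > Q) return 0;
    // Pascal triangle modulo MOD; O(Q^2) memory, small inputs only.
    std::vector<std::vector<std::int64_t>> C(Q + 1, std::vector<std::int64_t>(Q + 1, 0));
    for (int i = 0; i <= Q; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
    }
    std::int64_t ans = 0;
    for (int i = 1; i < n; ++i) {
        int v = order[i];
        int b = q[v];              // '?' inside the subtree below this edge
        int m = tgt[v] - one[v];   // how many of them must become 1 to balance
        for (int x = 0; x <= b; ++x) {
            if (R - x < 0 || R - x > Q - b) continue;
            std::int64_t ways = C[b][x] * C[Q - b][R - x] % MOD;
            ans = (ans + ways * (std::abs(x - m) % MOD)) % MOD;
        }
    }
    return ans;
}
```

Calling `solveSlow(5, {{1,2},{1,3},{1,4},{1,5}}, "0??1?", "10101")` corresponds to the third sample test case. The inner loop over `x` is the part that becomes too slow at full constraints, which is what idea 4 addresses.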
cpp
#include <bits/stdc++.h>
using namespace std;

const int MOD = 998244353;
const int G = 3;

int modpow(long long a, long long e) {
    long long r = 1;
    while (e) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return (int)r;
}

void ntt(vector<int>& a, bool invert) {
    int n = (int)a.size();

    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        int wlen = modpow(G, (MOD - 1) / len);
        if (invert) wlen = modpow(wlen, MOD - 2);

        for (int i = 0; i < n; i += len) {
            long long w = 1;
            for (int j = 0; j < len / 2; j++) {
                int u = a[i + j];
                int v = (int)(a[i + j + len / 2] * w % MOD);

                int x = u + v;
                if (x >= MOD) x -= MOD;
                int y = u - v;
                if (y < 0) y += MOD;

                a[i + j] = x;
                a[i + j + len / 2] = y;
                w = w * wlen % MOD;
            }
        }
    }

    if (invert) {
        int invn = modpow(n, MOD - 2);
        for (int &x : a) {
            x = (int)((long long)x * invn % MOD);
        }
    }
}

struct Query {
    int b;
    int L;
    int m;
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    const int MAXN = 200000;
    vector<int> fact(MAXN + 1), invfact(MAXN + 1);
    fact[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fact[i] = (long long)fact[i - 1] * i % MOD;
    }
    invfact[MAXN] = modpow(fact[MAXN], MOD - 2);
    for (int i = MAXN; i >= 1; i--) {
        invfact[i - 1] = (long long)invfact[i] * i % MOD;
    }

    auto C = [&](int n, int k) -> int {
        if (n < 0 || k < 0 || k > n) return 0;
        return (long long)fact[n] * invfact[k] % MOD * invfact[n - k] % MOD;
    };

    int tc;
    cin >> tc;
    while (tc--) {
        int N;
        cin >> N;

        vector<vector<int>> adj(N + 1);
        for (int i = 0; i < N - 1; i++) {
            int u, v;
            cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        string S, T;
        cin >> S >> T;

        vector<int> parent(N + 1, 0), order;
        order.reserve(N);
        parent[1] = -1;
        order.push_back(1);
        for (int i = 0; i < (int)order.size(); i++) {
            int v = order[i];
            for (int to : adj[v]) {
                if (to == parent[v]) continue;
                parent[to] = v;
                order.push_back(to);
            }
        }

        vector<int> subQ(N + 1), subFixedOne(N + 1), subTargetOne(N + 1);
        for (int i = 1; i <= N; i++) {
            subQ[i] = (S[i - 1] == '?');
            subFixedOne[i] = (S[i - 1] == '1');
            subTargetOne[i] = (T[i - 1] == '1');
        }
        for (int i = N - 1; i > 0; i--) {
            int v = order[i];
            int p = parent[v];
            subQ[p] += subQ[v];
            subFixedOne[p] += subFixedOne[v];
            subTargetOne[p] += subTargetOne[v];
        }

        int q = subQ[1];
        int fixedOnes = subFixedOne[1];
        int targetOnes = subTargetOne[1];
        int R = targetOnes - fixedOnes;

        if (R < 0 || R > q) {
            cout << 0 << '\n';
            continue;
        }

        bool useZeros = (R > q - R);
        int r = useZeros ? q - R : R;

        vector<Query> queries;
        queries.reserve(N - 1);
        int maxB = 0;

        for (int i = 1; i < N; i++) {
            int v = order[i];
            int b = subQ[v];
            int m = subTargetOne[v] - subFixedOne[v];

            // Use the smaller side of the edge.
            if (b > q - b) {
                b = q - b;
                m = R - m;
            }

            // If selecting zeros is better than selecting ones.
            if (useZeros) {
                m = b - m;
            }

            maxB = max(maxB, b);
            queries.push_back({b, 0, m});
        }

        long long ans = 0;

        // Only one possible assignment of the selected type.
        if (r == 0) {
            for (auto &qu : queries) {
                ans += llabs((long long)qu.m) % MOD;
                if (ans >= MOD) ans -= MOD;
            }
            cout << ans % MOD << '\n';
            continue;
        }

        int M = maxB + 1;
        for (auto &qu : queries) {
            long long L = (long long)qu.b - qu.m;
            if (L < 0) L = 0;
            if (L > M) L = M;
            qu.L = (int)L;
        }

        int totalWays = C(q, r);
        int chooseX = C(q - 1, r - 1);

        vector<int> P(M);
        for (int y = 0; y < M; y++) {
            P[y] = (long long)invfact[y] * invfact[q - r - y] % MOD;
        }

        int qLen = min(r, maxB) + 1;
        vector<int> Q(qLen), QX(qLen);
        for (int z = 0; z < qLen; z++) {
            Q[z] = (long long)invfact[z] * invfact[r - z] % MOD;
            QX[z] = (long long)z * Q[z] % MOD;
        }

        int queryCount = max(1, (int)queries.size());
        long double opt = sqrt(
            (long double)max(1, maxB) * max(1, qLen) * 18.0L / queryCount
        );
        int B = (int)opt;
        if (B < 64) B = 64;
        if (B > 1600) B = 1600;
        if (B > M) B = M;

        int blocks = (M + B - 1) / B;
        vector<vector<int>> groups(blocks + 1);
        for (int i = 0; i < (int)queries.size(); i++) {
            groups[queries[i].L / B].push_back(i);
        }

        int len = 1;
        while (len < B + qLen - 1) len <<= 1;

        vector<int> qhat(len), qxhat(len);
        for (int i = 0; i < qLen; i++) {
            qhat[i] = Q[i];
            qxhat[i] = QX[i];
        }
        ntt(qhat, false);
        ntt(qxhat, false);

        vector<int> curS(M, 0), curX(M, 0);
        vector<int> fa(len), fs(len), fx(len);

        for (int blockId = 0; blockId <= blocks; blockId++) {
            int tailStart = blockId * B;

            for (int qi : groups[blockId]) {
                Query &qu = queries[qi];
                int b = qu.b;
                int L = qu.L;

                long long s = curS[b];
                long long ux = curX[b];

                int end = min(L, M);
                for (int y = tailStart; y < end; y++) {
                    int z = b - y;
                    if (0 <= z && z < qLen) {
                        int val = (long long)P[y] * Q[z] % MOD;
                        s += val;
                        if (s >= MOD) s -= MOD;

                        int valx = (long long)P[y] * QX[z] % MOD;
                        ux += valx;
                        if (ux >= MOD) ux -= MOD;
                    }
                }

                int factor = (long long)fact[b] * fact[q - b] % MOD;
                int U = s * factor % MOD;
                int UX = ux * factor % MOD;
                int totalX = (long long)b * chooseX % MOD;

                int mm = qu.m % MOD;
                if (mm < 0) mm += MOD;

                long long h = ((long long)mm * totalWays - totalX
                               + 2LL * ((UX - (long long)mm * U) % MOD)) % MOD;
                if (h < 0) h += MOD;

                ans += h;
                ans %= MOD;
            }

            if (blockId == blocks) break;

            int start = blockId * B;
            int blockLen = min(B, M - start);

            fill(fa.begin(), fa.end(), 0);
            for (int i = 0; i < blockLen; i++) {
                fa[i] = P[start + i];
            }
            ntt(fa, false);

            for (int i = 0; i < len; i++) {
                fs[i] = (long long)fa[i] * qhat[i] % MOD;
                fx[i] = (long long)fa[i] * qxhat[i] % MOD;
            }
            ntt(fs, true);
            ntt(fx, true);

            int maxK = min(blockLen + qLen - 2, maxB - start);
            for (int k = 0; k <= maxK; k++) {
                int b = start + k;
                curS[b] += fs[k];
                if (curS[b] >= MOD) curS[b] -= MOD;
                curX[b] += fx[k];
                if (curX[b] >= MOD) curX[b] -= MOD;
            }
        }

        cout << ans % MOD << '\n';
    }

    return 0;
}
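The arrays `P`, `Q` and `QX` in the code can be read as the two factors of a truncated convolution. The following is an interpretation of the code above in its own variable names ($q$, $r$, $b$, $m$, $L$), offered as a reading aid rather than as the author's stated derivation. Substituting $y = b - x$ and $z = x$:

$$
\binom{b}{x}\binom{q-b}{r-x}
= b!\,(q-b)!\;\cdot\;\underbrace{\frac{1}{y!\,(q-r-y)!}}_{P[y]}\;\cdot\;\underbrace{\frac{1}{z!\,(r-z)!}}_{Q[z]},
\qquad y+z=b .
$$

The condition $x > m$ becomes $y < L$ with $L = b - m$, so

$$
U = b!\,(q-b)!\sum_{\substack{y < L \\ y+z=b}} P[y]\,Q[z],
\qquad
UX = b!\,(q-b)!\sum_{\substack{y < L \\ y+z=b}} P[y]\,\bigl(z\,Q[z]\bigr).
$$

Each query is therefore one coefficient of the product of $Q$ (or $QX$) with $P$ truncated at $L$. The block loop adds whole blocks of $P$ through NTT products into `curS` and `curX`, and the inner `for (int y = tailStart; y < end; y++)` loop adds the remaining partial block term by term.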

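Complexity is roughly:

$O\left(N + \frac{Q}{B}\cdot \text{NTT} + N\cdot B\right)$

where Q is the number of ? characters, B is the block size, and NTT is the cost of one block transform. The block size is chosen adaptively, and the solution works for total N ≤ 2e5.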
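For cross-checking the fast solution on small inputs, an exhaustive enumerator over all completions can be useful. The sketch below is an assumption-laden helper, not part of the original answer: the names `swapCost` and `bruteForce` are hypothetical, and it relies on the per-edge cost formula from key idea 1. It runs in time exponential in the number of `?` characters.

```cpp
#include <cstddef>
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

using Edges = std::vector<std::pair<int, int>>;

// Cost for one fully specified S, via the per-edge formula (key idea 1).
// Returns 0 when S and T contain different numbers of 1s (impossible case).
long long swapCost(int n, const Edges& edges, const std::string& S, const std::string& T) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // BFS order from node 1, so children come after their parent.
    std::vector<int> parent(n + 1, 0), order;
    parent[1] = -1;
    order.push_back(1);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int v = order[i];
        for (int to : adj[v]) {
            if (to == parent[v]) continue;
            parent[to] = v;
            order.push_back(to);
        }
    }
    // diff[v] = (#1 of S) - (#1 of T) inside the subtree of v.
    std::vector<int> diff(n + 1, 0);
    for (int v = 1; v <= n; ++v) diff[v] = (S[v - 1] == '1') - (T[v - 1] == '1');
    long long cost = 0;
    for (int i = n - 1; i > 0; --i) {
        int v = order[i];
        cost += std::abs(diff[v]);      // edge between v and its parent
        diff[parent[v]] += diff[v];
    }
    return diff[1] == 0 ? cost : 0;
}

// Sum of swapCost over every way to fill the '?' characters, modulo 998244353.
// Only practical when the number of '?' is small (roughly 20 or fewer).
long long bruteForce(int n, const Edges& edges, std::string S, const std::string& T) {
    const long long MOD = 998244353;
    std::vector<int> holes;
    for (int i = 0; i < n; ++i) {
        if (S[i] == '?') holes.push_back(i);
    }
    long long total = 0;
    for (unsigned long long mask = 0; mask < (1ULL << holes.size()); ++mask) {
        for (std::size_t k = 0; k < holes.size(); ++k) {
            S[holes[k]] = ((mask >> k) & 1ULL) ? '1' : '0';
        }
        total = (total + swapCost(n, edges, S, T)) % MOD;
    }
    return total;
}
```

Comparing `bruteForce` against the submitted program on random small trees is one way to gain confidence in the block decomposition before submitting.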
