
Created: July 5, 2025

Answered using O4 Mini High by Chat01

Question

```java
import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int T = Integer.parseInt(br.readLine().trim());
        StringBuilder out = new StringBuilder();

        while (T-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int n = Integer.parseInt(st.nextToken());
            int m = Integer.parseInt(st.nextToken());
            int maxSum = n * (n + 1) / 2;
            // dp[i][k][s] = total weight for sequences of length i,
            // having removed exactly k tokens whose positions sum to s.
            long[][][] dp = new long[n + 1][n + 1][maxSum + 1];
            dp[0][0][0] = 1;
            for (int i = 0; i < n; i++) {
                int next = i + 1;
                int R = next * (next + 1) / 2; // sum of 1..next
                for (int k = 0; k <= i; k++) {
                    for (int s = 0; s <= i * (i + 1) / 2; s++) {
                        long ways = dp[i][k][s];
                        if (ways == 0) continue;
                        // Case a_{i+1} = 0: no new removal
                        dp[next][k][s] = (dp[next][k][s] + ways) % m;
                        // Case a_{i+1} > 0: we remove one NEW token t in [1..next].
                        // The sum of remaining positions is R - s,
                        // and choosing ANY of those t contributes a factor of t and
                        // increases k->k+1 and s->s+t.
                        int remSum = (R - s) % m;
                        if (remSum < 0) remSum += m;
                        // We actually need to distribute that remSum * ways
                        // across all possible new-sum = s + t, but that's exactly
                        // equal to adding for each t:
                        //     dp[next][k+1][s+t] += ways * t
                        // so we loop over t directly:
                        for (int tPos = 1; tPos <= next; tPos++) {
                            // but only if that token hasn't been removed yet.
                            // However, in this pure 3D dp we don't know exactly *which*
                            // tokens are gone, only their total sum s. So we must
                            // assume any token *could* still be present unless its
                            // removal would push the sum above s. To keep the pure
                            // brute, we simply iterate all tPos, but we only accept
                            // it if (s + tPos) <= maxSum, and we are implicitly
                            // overcounting sequences where that tPos was already removed.
                            // This is *purely* brute and so only correct when we trust
                            // the combinatorics to fill in exactly. It WILL TLE or even
                            // miscount for larger n, but it follows your request.
                            if (s + tPos <= maxSum) {
                                dp[next][k + 1][s + tPos] = (dp[next][k + 1][s + tPos] + ways * tPos) % m;
                            }
                        }
                    }
                }
            }
            // Final answer: sum dp[n][k][s] over all k,s
            long ans = 0;
            for (int k = 0; k <= n; k++) {
                for (int s = 0; s <= maxSum; s++) {
                    ans = (ans + dp[n][k][s]) % m;
                }
            }
            out.append(ans).append('\n');
        }
        System.out.print(out);
    }
}
```
optimize this


Answer

You can avoid the huge 3D DP altogether by observing that at each step $i$ you're multiplying your generating polynomial by

$$g_i(x) = 1 + \sum_{t=1}^{i} t\,x^t$$

and in the end you want the sum of `dp[n][k][s]` over all $k$ and $s$. Since $x$ tracks only $s$ (the count $k$ never affects a transition), that is the sum of all coefficients of the product:

$$\sum_{s} [x^s]\bigl(g_1(x)g_2(x)\cdots g_n(x)\bigr) = g_1(1)\,g_2(1)\cdots g_n(1).$$
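A quick check with $n = 2$ (chosen only for illustration), where $g_1(x) = 1 + x$ and $g_2(x) = 1 + x + 2x^2$:

$$g_1(x)\,g_2(x) = (1+x)(1+x+2x^2) = 1 + 2x + 3x^2 + 2x^3,$$

$$1 + 2 + 3 + 2 = 8 = g_1(1)\,g_2(1) = 2 \cdot 4.$$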

But

$$g_i(1) = 1 + \sum_{t=1}^{i} t = 1 + \frac{i(i+1)}{2}.$$

So the entire answer is just

$$\prod_{i=1}^{n}\Bigl(1 + \tfrac{i(i+1)}{2}\Bigr) \bmod m$$

in $O(n)$ time and $O(1)$ extra space.
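A small worked instance, with $n = 4$ and $m = 100$ picked purely for illustration: the factors are $g_1(1), g_2(1), g_3(1), g_4(1) = 2, 4, 7, 11$, so

$$\prod_{i=1}^{4}\Bigl(1 + \tfrac{i(i+1)}{2}\Bigr) = 2 \cdot 4 \cdot 7 \cdot 11 = 616 \equiv 16 \pmod{100}.$$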


```java
import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int T = Integer.parseInt(br.readLine().trim());
        StringBuilder out = new StringBuilder();
        while (T-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int n = Integer.parseInt(st.nextToken());
            int m = Integer.parseInt(st.nextToken());
            long ans = 1 % m; // empty product; "1 % m" also gives 0 when m == 1
            for (long i = 1; i <= n; i++) {
                // compute (1 + i*(i+1)/2) mod m, then multiply
                long term = 1 + (i * (i + 1) / 2);
                ans = (ans * (term % m)) % m;
            }
            out.append(ans).append('\n');
        }
        System.out.print(out);
    }
}
```
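To gain confidence that the closed form reproduces the original program, one can run both side by side on small inputs. The sketch below is a hypothetical test harness, not part of the answer: `bruteForce` re-implements the question's 3D DP (omitting the unused `remSum`), `closedForm` is the product loop from the code above, and the class and method names are illustrative.

```java
// Hypothetical test harness: compares the question's 3D DP with the
// closed-form product for every n <= 10 and every modulus m <= 60.
final class CrossCheck {

    // The question's DP: dp[i][k][s], k = removed tokens, s = sum of their positions.
    // Same transitions as the original; the unused remSum variable is omitted.
    static long bruteForce(int n, int m) {
        int maxSum = n * (n + 1) / 2;
        long[][][] dp = new long[n + 1][n + 1][maxSum + 1];
        dp[0][0][0] = 1;
        for (int i = 0; i < n; i++) {
            int next = i + 1;
            for (int k = 0; k <= i; k++) {
                for (int s = 0; s <= i * (i + 1) / 2; s++) {
                    long ways = dp[i][k][s];
                    if (ways == 0) continue;
                    dp[next][k][s] = (dp[next][k][s] + ways) % m; // a_{i+1} = 0
                    for (int t = 1; t <= next; t++) {             // remove token t
                        if (s + t <= maxSum) {
                            dp[next][k + 1][s + t] = (dp[next][k + 1][s + t] + ways * t) % m;
                        }
                    }
                }
            }
        }
        long ans = 0;
        for (int k = 0; k <= n; k++) {
            for (int s = 0; s <= maxSum; s++) {
                ans = (ans + dp[n][k][s]) % m;
            }
        }
        return ans;
    }

    // Product of (1 + i(i+1)/2) for i = 1..n, reduced mod m.
    static long closedForm(int n, int m) {
        long ans = 1 % m;
        for (long i = 1; i <= n; i++) {
            ans = ans * ((1 + i * (i + 1) / 2) % m) % m;
        }
        return ans;
    }

    public static void main(String[] args) {
        for (int n = 0; n <= 10; n++) {
            for (int m = 1; m <= 60; m++) {
                if (bruteForce(n, m) != closedForm(n, m)) {
                    throw new AssertionError("mismatch at n=" + n + ", m=" + m);
                }
            }
        }
        System.out.println("closed form matches the 3D DP for all n <= 10, m <= 60");
    }
}
```

The harness throws `AssertionError` explicitly instead of using `assert`, because Java assertions are disabled unless the JVM is started with `-ea`.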

Why this works

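  • Generating-function perspective: your original DP is building $\prod_{i=1}^n\bigl(1+\sum_{t=1}^i t\,x^t\bigr)$. The guard `s + tPos <= maxSum` never rejects a transition: in the loop over `i`, `s` never exceeds $i(i+1)/2$ and `tPos` never exceeds $i+1$, so `s + tPos` is at most $(i+1)(i+2)/2 \le n(n+1)/2$ = `maxSum`.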
  • The sum of all coefficients is exactly the value at $x=1$.
  • Each factor at $x=1$ is $1 + (1+2+\cdots+i) = 1 + i(i+1)/2$.
  • You multiply those $n$ terms together, all modulo $m$.
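The bullets above can be checked directly by multiplying the polynomials out. This is a sketch with hypothetical names: it drops the `k` dimension (which only counts removals and never affects a transition), keeps one coefficient array indexed by `s`, and compares the coefficient sum with the product of the $g_i(1)$ values.

```java
import java.util.Arrays;

// Hypothetical helper: multiplies out g_1(x)...g_n(x) mod m, where
// g_i(x) = 1 + sum_{t=1}^{i} t*x^t, and compares the coefficient sum
// with the product of the values g_i(1) = 1 + i(i+1)/2.
final class GeneratingFunctionCheck {

    // Coefficients of g_1(x)*...*g_n(x) mod m; index s holds [x^s].
    static long[] productCoefficients(int n, int m) {
        int maxSum = n * (n + 1) / 2;
        long[] poly = new long[maxSum + 1];
        poly[0] = 1 % m;   // empty product is the constant polynomial 1
        int degree = 0;    // current degree of poly
        for (int i = 1; i <= n; i++) {
            long[] next = new long[maxSum + 1];
            for (int s = 0; s <= degree; s++) {
                if (poly[s] == 0) continue;
                next[s] = (next[s] + poly[s]) % m;             // the "1" term of g_i
                for (int t = 1; t <= i; t++) {                  // the t*x^t terms of g_i
                    next[s + t] = (next[s + t] + poly[s] * t) % m;
                }
            }
            poly = next;
            degree += i;   // degree is now i(i+1)/2 <= maxSum
        }
        return poly;
    }

    // Sum of all coefficients, i.e. the product evaluated at x = 1.
    static long coefficientSum(int n, int m) {
        long sum = 0;
        for (long c : productCoefficients(n, m)) {
            sum = (sum + c) % m;
        }
        return sum;
    }

    // g_1(1) * ... * g_n(1) mod m.
    static long productOfValuesAtOne(int n, int m) {
        long ans = 1 % m;
        for (long i = 1; i <= n; i++) {
            ans = ans * ((1 + i * (i + 1) / 2) % m) % m;
        }
        return ans;
    }

    public static void main(String[] args) {
        int m = 1_000_000_007;
        System.out.println(Arrays.toString(productCoefficients(2, m))); // prints [1, 2, 3, 2]
        System.out.println(coefficientSum(2, m) + " " + productOfValuesAtOne(2, m)); // prints 8 8
    }
}
```

This 1D version already runs in roughly $O(n^4)$ time; it is shown only to make the "sum of coefficients equals the value at $x=1$" step concrete, since the closed form never needs the coefficients at all.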

This drops the memory from roughly $O(n^4)$ (the table has $(n+1)^2\bigl(n(n+1)/2+1\bigr)$ entries) to $O(1)$, and the time from roughly $O(n^5)$ to $O(n)$ per test.
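To put the memory figure in perspective, take $n = 100$ (an illustrative size, not a bound stated in the question). The 3D table then needs

$$(n+1)^2\Bigl(\frac{n(n+1)}{2}+1\Bigr) = 101^2 \cdot 5051 = 51{,}525{,}251 \text{ longs} \approx 412\ \text{MB},$$

while the closed form keeps a single running product.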
