beet's soil

競プロのことなど

2013 TCO Algorithm Round 3A - TrickyInequality

すげー

↓の最後の問題について詳しく書きます
beet-aizu.hatenablog.com


求めるものは、
 \displaystyle \sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t} \binom{s-\sum_{i=1}^{n}{x_i}}{m-n}

一つ箱を追加して「s以下」を「ちょうどs」に言い換える
追加した箱以外には一つ以上ボールを入れないといけないので
 s- (\sum_{i=1}^{n}{x_i}) - (m-n) 個のボールを  m-n + 1 個の箱に入れる重複組み合わせになる。


ここで、  Y = \sum_{i=1}^{n}{x_i} と置くと、
  f(Y) = \binom{s-Y}{m-n} = \cfrac{(s-Y)(s-Y-1)\ldots(s-Y-(m-n)+1)}{(m-n)!} は Y のm-n次多項式になる。
この多項式 O((m-n)^2) くらいで陽に計算できる
 f(Y) = \displaystyle \sum_{k=0}^{m-n} a_k Y^k とおく。


最初の式に対して式変形を行うと
 \displaystyle \sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t} \binom{s-\sum_{i=1}^{n}{x_i}}{m-n} = \displaystyle \sum_{k=0}^{m-n} a_k \left (\sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t} Y^k \right) = \displaystyle \sum_{k=0}^{m-n} a_k \left (\sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t} \left (\sum_{i=1}^{n} x_i \right )^k \right )
となる。


 \displaystyle \left (\sum_{i=1}^{n} x_i \right )^k は、多項係数の性質から  \displaystyle \sum_{d_1 + d_2 + \ldots + d_n = k} k! \prod_{i=1}^{n} \frac{{x_i}^{d_i}}{d_i!}  である。

したがって、  \displaystyle \sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t} \left (\sum_{i=1}^{n} x_i \right )^k  = \displaystyle \sum_{x_1=1}^{t} \sum_{x_2=1}^{t} \ldots \sum_{x_n=1}^{t}  \sum_{d_1 + d_2 + \ldots + d_n = k} k! \prod_{i=1}^{n} \frac{{x_i}^{d_i}}{d_i!} = \sum_{d_1 + d_2 + \ldots + d_n = k} k! \prod_{i=1}^{n} \left( \sum_{x_i=1}^{t} \frac{{x_i}^{d_i}}{d_i!} \right)

これは  \displaystyle \left( \sum_{x=1}^{t} \frac{{x}^{d}}{d!} \right)  d (0 \le d \le m-n) ごとに計算したものを多項式として捉え、それを繰り返し自乗法を用いて  n 乗することで  O((m-n)^2 \log n ) で求められる。

あとはこれを実装すれば終わり 全体の計算量は  O((m-n)^2 \log n ) になる。

ソースコード

modintは省略

class TrickyInequality {
public:
  using ll = long long;
  using M = Mint<int>;
  vector<M> mul(vector<M> as,vector<M> bs){
    int sz=as.size()+bs.size()-1;
    vector<M> cs(sz,0);
    for(int i=0;i<(int)as.size();i++)
      for(int j=0;j<(int)bs.size();j++)
        cs[i+j]+=as[i]*bs[j];
    return cs;
  }

  int countSolutions(long long s, int t, int n, int m) {
    int d=m-n;

    vector<M> vs({1});
    for(int i=0;i<d;i++){
      vector<M> ws;
      ws.emplace_back(s-i);
      ws.emplace_back(-M(1));
      vs=mul(vs,ws);
    }

    vector<M> sm(d+1,0);
    for(int i=1;i<=t;i++){
      M po(1);
      for(int j=0;j<=d;j++){
        sm[j]+=po;
        po*=M(i);
      }
    }

    {
      M res{1};
      for(int j=1;j<=d;j++){
        res*=M(j);
        sm[j]/=res;
      }

      for(int j=0;j<=d;j++) vs[j]/=res;
    }

    vector<M> dp(d+1,M(0));
    dp[0]=M(1);
    {
      int p=n;
      while(p){
        if(p&1) dp=mul(dp,sm);
        sm=mul(sm,sm);
        p>>=1;
        dp.resize(d+1);
        sm.resize(d+1);
      }
    }

    M ans{vs[0]*dp[0]};
    {
      M res{1};
      for(int j=1;j<=d;j++){
        res*=M(j);
        ans+=vs[j]*dp[j]*res;
      }
    }

    return ans.v;
  }
};