TopCoder SRM440 Div1Hard
問題
約数に1以外の平方数を含まないような数を「square-free number」と呼ぶ。各要素がsquare-free numberで、要素の積もsquare-free numberであるような整数の集合は、square-free setである。1〜K個の要素を含み、各要素が2〜Nの範囲にあるようなsquare-free setの数を求めよ。
解法
square-free setの要素の積を素因数分解してみると、同じ素数は高々1回しか現れない。これを利用して、「N以下の素数が並んでる中から適当に積がN以下になるように取るという操作をK回以下行う」場合の数を求めればいいことがわかる。
ここで、500以下とはいえ素数はたくさんあるので、すべての素数に対してビットDPをかけるのは非現実的である。しかし、素数を「2,3,5,7,11,13,17,19,23」と「29以上」の2グループに分けると、後者のグループの中の数同士はくっつかないことがわかる。(29*31>500より。また、実際は23も後者に含めてよい)前者のグループの素数に対してはビットDPを行い、後者のグループの素数に対しては何個使ったかだけ持っておけばよい。また、今までに何個の数を作ったかも覚えておく。
後者のグループの残り個数だけ覚えておいても、「N=70において29は2とはくっつくけど3とはくっつかない」とかそういうことを区別する必要が出てくる。しかし、ビットDPグループの積が大きいものから考えていくと、小さいものなら後者のグループの素数は大きいもので使えた物すべて使えるため、何個使ったかを見れば解けることがわかる。
最後に、「後者のグループの素数単体で使う」場合を、Cなどを使って求めれば答えになる。そもそも後者のグループの素数が存在しない場合は、普通に足せば良い。
コード
#include <cstdio> #define MOD 1000000007 #define ADD(X,Y) ( (X) = ( (X) + (Y) ) % MOD ) class SquareFreeSets { int primes[500], pcnt; bool erat[501]; int C[510][510]; void make_prime(int N) { for(int i=2;i<=N;i++) erat[i] = false; pcnt = 0; for(int i=2;i<=N;i++){ if(!erat[i]){ primes[pcnt++] = i; for(int j=i*i;j<=N;j+=i) erat[j] = true; } } } void make_C() { for(int i=0;i<=500;i++) C[0][i] = 0; C[0][0] = 1; for(int j=1;j<=500;j++){ C[j][0] = 1; for(int k=1;k<=500;k++) C[j][k] = (C[j-1][k] + C[j-1][k-1]) % MOD; } } int dp[1<<9][2]; long long dp2[1<<9][10][10]; public: int countPerfect(int N, int K) { make_prime(500); make_C(); for(int i=0;i<(1<<9);i++){ dp[i][0] = dp[i][1] = 0; long long mul = 1; for(int j=0;j<9;j++) if(i&(1<<j)) mul *= primes[j]; if(mul > N) continue; dp[i][0] = 1; dp[i][1] = 0; for(int k=9;;k++) if(primes[k] * mul <= N) dp[i][1]++; else break; } int pv = 0; for(int i=0;i<pcnt;i++) if(primes[i] <= N) pv++; pv -= 9; long long ret = 0; for(int j=0;j<(1<<9);j++){ for(int k=0;k<=9;k++){ for(int l=0;l<=9;l++){ dp2[j][k][l] = 0; } } } dp2[0][0][0] = 1; for(int i=N;i>=2;i--){ int pos = -1; for(int j=0;j<(1<<9);j++){ long long mul = 1; for(int k=0;k<9;k++) if(j&(1<<k)){ mul *= primes[k]; if(mul>N) mul = 0; } if(mul==i) pos = j; } //printf("%d\n", pos); if(pos==-1) continue; for(int j=0;j<(1<<9);j++) if((j&pos)==0){ for(int k=0;k<=9;k++){ for(int l=0;l<=9;l++) if(dp2[j][k][l]){ if(l<9) ADD(dp2[j | pos][k][l+1], dp2[j][k][l] * dp[pos][0]); if(k<9&&l<9&&dp[pos][1]>k) ADD(dp2[j | pos][k+1][l+1], dp2[j][k][l] * (dp[pos][1] - k)); } } } } for(int i=0;i<(1<<9);i++){ for(int j=0;j<=9;j++){ for(int k=0;k<=9;k++){ if(pv>0){ for(int l=0;l<=K-k;l++) if(dp2[i][j][k]){ ADD(ret, dp2[i][j][k] * C[pv-j][l]); } }else{ if(k<=K) ADD(ret, dp2[i][j][k]); } } } } ADD(ret, MOD-1); return (int)(ret % MOD); } };