TopCoder SRM440 Div1Hard

問題

約数に1以外の平方数を含まないような数を「square-free number」と呼ぶ。各要素がsquare-free numberで、要素の積もsquare-free numberであるような整数の集合は、square-free setである。1〜K個の要素を含み、各要素が2〜Nの範囲にあるようなsquare-free setの数を求めよ。

解法

square-free setの要素の積を素因数分解してみると、同じ素数は高々1回しか現れない。これを利用して、「N以下の素数が並んでる中から適当に積がN以下になるように取るという操作をK回以下行う」場合の数を求めればいいことがわかる。
ここで、500以下とはいえ素数はたくさんあるので、すべての素数に対してビットDPをかけるのは非現実的である。しかし、素数を「2,3,5,7,11,13,17,19,23」と「29以上」の2グループに分けると、後者のグループの中の数同士はくっつかないことがわかる。(29*31>500より。また、実際は23も後者に含めてよい)前者のグループの素数に対してはビットDPを行い、後者のグループの素数に対しては何個使ったかだけ持っておけばよい。また、今までに何個の数を作ったかも覚えておく。
後者のグループの残り個数だけ覚えておいても、「N=70において29は2とはくっつくけど3とはくっつかない」とかそういうことを区別する必要が出てくる。しかし、ビットDPグループの積が大きいものから考えていくと、小さいものなら後者のグループの素数は大きいもので使えた物すべて使えるため、何個使ったかを見れば解けることがわかる。
最後に、「後者のグループの素数単体で使う」場合を、Cなどを使って求めれば答えになる。そもそも後者のグループの素数が存在しない場合は、普通に足せば良い。

コード

#include <cstdio>

#define MOD 1000000007
#define ADD(X,Y) ( (X) = ( (X) + (Y) ) % MOD )

class SquareFreeSets
{
	int primes[500], pcnt;
	bool erat[501];
	int C[510][510];
	
	void make_prime(int N)
	{
		for(int i=2;i<=N;i++) erat[i] = false;
		pcnt = 0;
		for(int i=2;i<=N;i++){
			if(!erat[i]){
				primes[pcnt++] = i;
				for(int j=i*i;j<=N;j+=i) erat[j] = true;
			}
		}
	}
	
	void make_C()
	{
		for(int i=0;i<=500;i++) C[0][i] = 0;
		C[0][0] = 1;
		for(int j=1;j<=500;j++){
			C[j][0] = 1;
			for(int k=1;k<=500;k++) C[j][k] = (C[j-1][k] + C[j-1][k-1]) % MOD;
		}
	}
	
	int dp[1<<9][2];
	long long dp2[1<<9][10][10];
public:
	int countPerfect(int N, int K)
	{
		make_prime(500);
		make_C();
		
		for(int i=0;i<(1<<9);i++){
			dp[i][0] = dp[i][1] = 0;
			
			long long mul = 1;
			for(int j=0;j<9;j++) if(i&(1<<j)) mul *= primes[j];
			if(mul > N) continue;
					
			dp[i][0] = 1;
			dp[i][1] = 0;
			for(int k=9;;k++) if(primes[k] * mul <= N) dp[i][1]++; else break;
			
		}
		
		int pv = 0;
		for(int i=0;i<pcnt;i++) if(primes[i] <= N) pv++;
		pv -= 9;
		
		long long ret = 0;
		
		for(int j=0;j<(1<<9);j++){
			for(int k=0;k<=9;k++){
				for(int l=0;l<=9;l++){
					dp2[j][k][l] = 0;
				}
			}
		}
		dp2[0][0][0] = 1;
		
		for(int i=N;i>=2;i--){
			int pos = -1;
			for(int j=0;j<(1<<9);j++){
				long long mul = 1;
				for(int k=0;k<9;k++) if(j&(1<<k)){
					mul *= primes[k];
					if(mul>N) mul = 0;
				}
				if(mul==i) pos = j;
			}
			
			//printf("%d\n", pos);
			if(pos==-1) continue;
			for(int j=0;j<(1<<9);j++) if((j&pos)==0){
				for(int k=0;k<=9;k++){
					for(int l=0;l<=9;l++) if(dp2[j][k][l]){
						if(l<9) ADD(dp2[j | pos][k][l+1], dp2[j][k][l] * dp[pos][0]);
						if(k<9&&l<9&&dp[pos][1]>k) ADD(dp2[j | pos][k+1][l+1], dp2[j][k][l] * (dp[pos][1] - k));
					}
				}
			}
		}
		
		for(int i=0;i<(1<<9);i++){
			for(int j=0;j<=9;j++){
				for(int k=0;k<=9;k++){
					if(pv>0){
						for(int l=0;l<=K-k;l++) if(dp2[i][j][k]){
							ADD(ret, dp2[i][j][k] * C[pv-j][l]);
						}
					}else{
						if(k<=K) ADD(ret, dp2[i][j][k]);
					}
				}
			}
		}
		
		ADD(ret, MOD-1);
		return (int)(ret % MOD);
	}
};