TopCoder SRM444 Div1Hard
問題
4は不吉な数だから、4のパターンを含む数を避けたい。具体的には、次の条件を満たす正の整数の数を求めたい。
- その整数は、高々N桁である。
- その整数は、「4444」という並びを含まない。例えば、「45444474」とかは禁止である。
- その整数の桁数は、「10以上で、各桁が4のみからなる」ある整数の倍数であってはならない。すなわち、桁数が「44, 444, 4444, …」の倍数であってはならない。
解法
上の位から順に埋めていくことを考える。今、「最後に4が何個連続しているか」に注目して場合の数を求めると、漸化式が立てられることがわかる。これは線形の漸化式だから、行列を使ってすぐに求めることが出来る。掛けあわせる行列をXとおくと、(1 + X + X^2 + … + X^(N-1)) * (8, 1, 0, 0) (実際は縦長)の各要素の和を求めればいいことがわかる。…の中に「44」とか「132」とかを含んではいけないが、後で包除原理をつかって追いだしてやればよい。
コード
#include <cstdio> #define MOD 1000000007 #define ADD(X,Y) ( (X) = ((X) + (Y))%MOD ) struct matrix { long long val[4][4]; }; void print(matrix& m) { for(int i=0;i<4;i++){ for(int j=0;j<4;j++) printf("%lld ", m.val[i][j]); puts(""); } } matrix operator*(const matrix& a, const matrix& b) { matrix ret; for(int i=0;i<4;i++){ for(int j=0;j<4;j++){ ret.val[i][j] = 0; for(int k=0;k<4;k++){ ret.val[i][j] += a.val[i][k] * b.val[k][j]; } ret.val[i][j] %= MOD; } } return ret; } matrix operator+(const matrix& a, const matrix& b) { matrix ret; for(int i=0;i<4;i++){ for(int j=0;j<4;j++){ ret.val[i][j] = (a.val[i][j] + b.val[i][j]) % MOD; } } return ret; } matrix pow(matrix X, long long p) { if(p==1) return X; matrix tmp = pow(X, p/2); tmp = tmp * tmp; if(p%2==1) tmp = tmp * X; return tmp; } matrix powsum(matrix X, long long p) { if(p==1) return X; if(p%2==1){ matrix tmp = powsum(X, p-1) * X; return tmp + X; } matrix tmp = powsum(X, p/2); tmp = tmp + tmp * pow(X, p/2); return tmp; } long long calc(matrix& m) { return 8 * (m.val[0][0] + m.val[1][0] + m.val[2][0] + m.val[3][0]) + (m.val[0][1] + m.val[1][1] + m.val[2][1] + m.val[3][1]); } long long gcd(long long x, long long y) { if(y==0) return x; return gcd(y, x%y); } const int b[4][4] = { {9, 9, 9, 9}, {1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0} }; long long bad[9] = { 44LL, 444LL, 4444LL, 44444LL, 444444LL, 4444444LL, 44444444LL, 444444444LL, 4444444444LL }; class AvoidFour { matrix bas; public: int count(long long N) { if(N==1) return 9; //N--; for(int i=0;i<4;i++) for(int j=0;j<4;j++) bas.val[i][j] = b[i][j]; matrix tmp; long long ret = 9; tmp = powsum(bas, N-1); print(tmp); ADD(ret, calc(tmp)); matrix pb, t2; for(int i=1;i<512;i++){ long long lcm = 1; int sign = 0; for(int j=8;j>=0;j--) if(i&(1<<j)){ long long g = gcd(lcm, bad[j]); lcm *= bad[j] / g; sign ^= 1; if(lcm > N) goto nex; } //printf("%lld %d\n", lcm, sign); pb = pow(bas, lcm); if(N/lcm>1){ t2 = powsum(pb, N / lcm - 1); t2 = t2 * pow(bas, lcm-1) + pow(bas, lcm-1); }else{ t2 = pow(bas, lcm-1); } if(sign==0) ADD(ret, calc(t2)); else ADD(ret, (MOD - calc(t2)%MOD)); nex: continue; } return (int)(ret % MOD); } };