TopCoder SRM444 Div1Hard

問題

4は不吉な数だから、4のパターンを含む数を避けたい。具体的には、次の条件を満たす正の整数の数を求めたい。

  • その整数は、高々N桁である。
  • その整数は、「4444」という並びを含まない。例えば、「45444474」とかは禁止である。
  • その整数の桁数は、「10以上で、各桁が4のみからなる」ある整数の倍数であってはならない。すなわち、桁数が「44, 444, 4444, …」の倍数であってはならない。

解法

上の位から順に埋めていくことを考える。今、「最後に4が何個連続しているか」に注目して場合の数を求めると、漸化式が立てられることがわかる。これは線形の漸化式だから、行列を使ってすぐに求めることが出来る。掛けあわせる行列をXとおくと、(1 + X + X^2 + … + X^(N-1)) * (8, 1, 0, 0) (実際は縦長)の各要素の和を求めればいいことがわかる。…の中に「44」とか「132」とかを含んではいけないが、後で包除原理をつかって追いだしてやればよい。

コード

#include <cstdio>

#define MOD 1000000007
#define ADD(X,Y) ( (X) = ((X) + (Y))%MOD )

struct matrix
{
	long long val[4][4];
};

void print(matrix& m)
{
	for(int i=0;i<4;i++){
		for(int j=0;j<4;j++) printf("%lld ", m.val[i][j]);
		puts("");
	}
}

matrix operator*(const matrix& a, const matrix& b)
{
	matrix ret;
	for(int i=0;i<4;i++){
		for(int j=0;j<4;j++){
			ret.val[i][j] = 0;
			for(int k=0;k<4;k++){
				ret.val[i][j] += a.val[i][k] * b.val[k][j];
			}
			ret.val[i][j] %= MOD;
		}
	}
	return ret;
}

matrix operator+(const matrix& a, const matrix& b)
{
	matrix ret;
	for(int i=0;i<4;i++){
		for(int j=0;j<4;j++){
			ret.val[i][j] = (a.val[i][j] + b.val[i][j]) % MOD;
		}
	}
	return ret;
}

matrix pow(matrix X, long long p)
{
	if(p==1) return X;
	matrix tmp = pow(X, p/2);
	tmp = tmp * tmp;
	if(p%2==1) tmp = tmp * X;
	
	return tmp;
}

matrix powsum(matrix X, long long p)
{
	if(p==1) return X;
	if(p%2==1){
		matrix tmp = powsum(X, p-1) * X;
		return tmp + X;
	}
	matrix tmp = powsum(X, p/2);
	tmp = tmp + tmp * pow(X, p/2);
	return tmp;
}

long long calc(matrix& m)
{
	return 8 * (m.val[0][0] + m.val[1][0] + m.val[2][0] + m.val[3][0]) + (m.val[0][1] + m.val[1][1] + m.val[2][1] + m.val[3][1]);
}

long long gcd(long long x, long long y)
{
	if(y==0) return x;
	return gcd(y, x%y);
}

const int b[4][4] = {
	{9, 9, 9, 9},
	{1, 0, 0, 0},
	{0, 1, 0, 0},
	{0, 0, 1, 0}
};

long long bad[9] = {
	44LL,
	444LL,
	4444LL,
	44444LL,
	444444LL,
	4444444LL,
	44444444LL,
	444444444LL,
	4444444444LL
};

class AvoidFour
{
	matrix bas;
public:
	int count(long long N)
	{
		if(N==1) return 9;
		//N--;
		
		for(int i=0;i<4;i++)
			for(int j=0;j<4;j++) bas.val[i][j] = b[i][j];
			
		matrix tmp;
		long long ret = 9;
		
		tmp = powsum(bas, N-1);
		print(tmp);
		ADD(ret, calc(tmp));
		
		matrix pb, t2;
		
		for(int i=1;i<512;i++){
			long long lcm = 1;
			int sign = 0;
			for(int j=8;j>=0;j--)	if(i&(1<<j)){
				long long g = gcd(lcm, bad[j]);
				lcm *= bad[j] / g;
				sign ^= 1;
				if(lcm > N) goto nex;
			}
			
			//printf("%lld %d\n", lcm, sign);
			pb = pow(bas, lcm);
			if(N/lcm>1){
				t2 = powsum(pb, N / lcm - 1);
				t2 = t2 * pow(bas, lcm-1) + pow(bas, lcm-1);
			}else{
				t2 = pow(bas, lcm-1);
			}
			
			if(sign==0) ADD(ret, calc(t2));
			else ADD(ret, (MOD - calc(t2)%MOD));
nex:
			continue;
		}
		
		return (int)(ret % MOD);
	}
};