TopCoder SRM445 Div1Hard

問題

大文字小文字含めて、アルファベットは52文字ある。単純な換字暗号は、元の文字列の文字ごとに、変換後の文字を決めておいて変換する暗号化方法である。今回、ある文字を、同じ文字(大文字、小文字は区別せず)に変換してはならない。たとえば、JをjやJに変換してはならない。元の文字列と、変換後の文字列が与えられたとき、何通りの換字規則があるかを1234567891で割った余りを求めよ。

解法

まず、文字列に現れる文字については対応は一意に決まる。まず、それらに対して矛盾が生じてないか(異なる文字が同じ文字に対応してないか、同じ文字が異なる文字に変換されてないか、同じ文字に変換されてないか)を確かめ、矛盾があれば0を返す。
注目するのは現れない文字である。変換対応表(ある文字からある文字へ変換が許されるか許されないかを書いた表)を想像すると、何箇所か、「同じ文字に変換しちゃいけないルール」によって×が書いてある。ここで求めたいのは、変化対応表のすべての行、列に1個ずつ○を書く方法である(当然×が書いてあるところには書いちゃいけない)。ここで、×は同じ文字への変換の場所にしか現れない、すなわちあまり現れない。すると、包除原理が使えそうに思えてくる。求める答えは、「(-1^i * i箇所の×の上に○を書いて、他は全部の場所に○を書いていい場合の数)の総和」になる。「他は全部の場所に○を書いていい場合の数」というのは階乗で一発で求まる。また、「i箇所の×の上に○を書く場合の数」は、変換対応表の様子を見るとDPで求められることがわかる。

コード

#include <string>

using namespace std;

#define MOD 1234567891
#define ADD(X,Y) (X) = ( (X) + (Y) ) % MOD

int c2i(char c)
{
	if(c>='a') return (c-'a')*2;
	return (c-'A')*2+1;
}

class TheEncryptionDivOne
{
	int mt[52];
	long long dp[2][53];
	long long kaijo[53];
	
	int c1[26], c2[26];
	
public:
	int count(string msg, string enc)
	{
		for(int i=0;i<52;i++) mt[i] = -1;
		for(int i=0;i<26;i++) c1[i] = c2[i] = 0;
		
		for(int i=0;i<msg.size();i++){
			int v1=c2i(msg[i]), v2=c2i(enc[i]);
			if(mt[v1]==-1){
				mt[v1] = v2;
				if(v1/2==v2/2) return 0;
			}else if(mt[v1]!=v2) return 0;
		}
		for(int i=0;i<52;i++)
			for(int j=0;j<52;j++) if(i!=j){
				if(mt[i]!=-1 && mt[i]==mt[j]) return 0;
			}
			
		for(int i=0;i<52;i++){
			if(mt[i]==-1) c1[i/2]++;
		}
		for(int i=0;i<52;i++){
			int p = -1;
			for(int j=0;j<52;j++) if(mt[j]==i) p = 1;
			if(p==-1) c2[i/2]++;
		}

		for(int i=0;i<53;i++) dp[0][i] = dp[1][i] = 0;
		int t = 0;
		dp[t][0] = 1;
		
		for(int i=0;i<26;i++){
			if(c1[i]*c2[i]==0) continue;
			
			for(int j=0;j<=52;j++) dp[1-t][j] = 0;
			
			if(c1[i]==2&&c2[i]==2){
				for(int j=0;j<=52;j++){
					ADD(dp[1-t][j], dp[t][j]);
					if(j>0) ADD(dp[1-t][j], dp[t][j-1]*4);
					if(j>1) ADD(dp[1-t][j], dp[t][j-2]*2);
				}
			}else if(c1[i]==2||c2[i]==2){
				for(int j=0;j<=52;j++){
					ADD(dp[1-t][j], dp[t][j]);
					if(j>0) ADD(dp[1-t][j], dp[t][j-1]*2);
				}
			}else{
				for(int j=0;j<=52;j++){
					ADD(dp[1-t][j], dp[t][j]);
					if(j>0) ADD(dp[1-t][j], dp[t][j-1]);
				}
			}
			
			t = 1-t;
		}
		
		kaijo[0] = 1;
		for(int i=1;i<=52;i++) kaijo[i] = (kaijo[i-1]*i)%MOD;
		
		long long ret = 0;
		int sm = 0;
		for(int i=0;i<26;i++) sm += c1[i];
		
		for(int i=0;i<=sm;i++){
			if(i%2==0) ADD(ret, kaijo[sm-i] * dp[t][i]);
			else ADD(ret, MOD - (kaijo[sm-i] * dp[t][i])%MOD);
		}
		
		return ret;
	}
};