F - Three Gluttons(CODE FESTIVAL 2017 qual C)

この問題です。
F - Three Gluttons


というわけで解説と考察の道筋を書きます。

実際にした考察

「AとBの残っている寿司」と「あと何個食べるか」を引数にDPが出来そう

  • 「状態」が2^Nになるなぁ...。状態をNで表せるとO(N^3)で解けそう。
  • でも表せたとしてDPの遷移を思いつかないな...

...色々試します...
始めは必ず取らないといけないとか、最後二つは絶対に取れないとか考えましたが、そこから派生出来ない...
(真隣に依存する操作が無い為、端が決まっても嬉しくないなぁ...)

そのうち「逆順をやってみるか」と思います。
逆順で状態をNで表せるように頑張ります。
...そもそもこの問題の場合の逆順ってなんだ???(しばらく混乱)
(ちなみに考察中は逆順に「食べる」という用語で考察したんですが、書いてみたら紛らわしかったので「吐き出す」と書いておきます)

  1. けんかしないで最後まで取る=>0個のこる
  2. そこから一個取る前の状態に戻す
  3. ああ、後ろ二つは絶対に取れないんだっけ。後ろから三つ目に吐き出すか?
  4. なんで後ろから三つ目でないといけないんだ??<=ここ詰めよう
  5. DPみたいにしたいので、今吐き出したものの順位をa,b(0<=a,b<N)と置くか

Aさんが一個目にxを吐き出すとき、他の人はxを食べられないので、Bさんがxをb以上の順位にしていると困るな...
(補足:Bさんがxをbよりも上の順位にしている場合、xもB[b]も残っている状況ではBさんはxを取りたいのでけんかが起こります)
=>逆に言えば、Aさんが吐き出せるのはBさんにとってbよりも下の順位のものだけだな???
ということは...
Aにとってaより上の順位のもの:必ず吐き出すことが可能、下の順位のもの:後からは絶対に吐き出せない
(補足:a<jについてAさんはA[j]よりもA[a]を先に食べる<=>A[j]よりもA[a]を後に吐き出す)
Bも同じ
=>これなら「AとBの残っている(吐き出せる寿司)寿司」をa,bという境界を持ってやることでN^2通りで表せる!
=>よしO(N^3)DP出来た!!!

...これは罠で、DPの状態(引数)は決めることが出来ましたが、まだ遷移が出来ていません。
考えます。

吐き出せるやつを全部試すとN^2かかって間に合わないし...
...思いつかないですね。

仕方がないので「吐き出せるやつを全部試す」の方針をもう一度考えます。
...そういえばこれって、前の方のやつを全部足す形だな...
=>累積和が使えるパターンでは!??(ここは一般的なテクとして知識を持っておかないと辛い)

こういう問題の場合、僕はコードを実際に書いて、それを無理やり整形して解くことが多いので、出来そうという方針が立ったこの段階でコードを書き始めます。

DPの説明

後半にコードが書いてあるので、そちらと見比べながら見てください。

DPの定義式を考えます。
本番中僕は言語化せずに感覚的にやってしまったのですが、頑張って言語化します。

DP[a][b][taberu]を

  • 「Aはa、Bはbを直前に吐き出した」つまり「Aはaより好きなものを、Bはbより好きなものを吐き出すことが出来る」
  • 「A,B,Cはあとtaberu個吐き出す」

の条件を満たすようにCを作ったとき、作れるCの通り数
とします。
A,Bの先頭を吐き出すまでCの構成を続けるので、完成した場合であるa=0,b=0のときCの通り数は1です。
つまり
DP[0][0][0] = 1
です。

例えば、
入力例 3

6
1 2 3 4 5 6
2 1 4 3 6 5

の時は
DP[2][2][1]=20
みたいになります。
a,b,は0-indexed、taberuは個数(0個1個2個...N個)であることに注意してください。
(計算方法は後述)


「次にA,Bが吐き出すものの順位」をnext_a,next_bとします。(これは二重ループで決め打ちします)
A[i],B[i]を問題文通りに、i番目に好きな寿司とします。(コードでは0-indexed)
Aがxが何番目に好きかをN_EATED_A[x]、Bがxが何番目に好きかをN_EATED_B[x]とします。(変数の命名は突っ込まない約束です)
AがA[next_a]を吐き出す為には、BがB[next_b]よりもA[next_a]が嫌いである、つまりnext_b < N_EATED_B[A[next_a]]である必要があります。
同様にBがB[next_b]を吐き出す為にはnext_a < N_EATED_A[B[next_b]]である必要があります。
これを満たしているとき、AはA[next_a]、BはB[next_b]を吐き出します。
この時のCの通り数を計算します。

Cは自由に決められるので、吐き出すごとに生まれる制約を使ってCを少しずつ決めてやるイメージでやります。

Cが吐き出すものをA,Bと同じようにC[next_c]と仮置きします。あとN_EATED_C[x]も同じように仮置きします。
まず、A,Bの時と同じ理由で、next_c < N_EATED_C[A[next_a]] かつ next_c < N_EATED_C[B[next_b]]である必要が有ります。
なので、next_cの前にA[next_a]とB[next_b]を好きな順番で入れます。
また、これまたA,Bのときと同じ理由で、next_a < N_EATED_A[C[next_c]] かつnext_b < N_EATED_B[C[next_c]]である必要が有ります。
これを満たす為に、C[next_c]は上の条件を満たす寿司から一つ選んで吐き出します。
f:id:eiya5498513:20171023103858j:plain
この通り数が、コードの

*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3)

に相当します。
tabetaは「今までに吐き出した数」です。(コードは本番のものなので、「吐き出す」が「食べる」になっています)
始め二つが、「A[next_a]とB[next_b]を好きな順番で入れる」の通り数です。Cの中で順番が確定しているのは「今までにA,B,Cが吐き出した総数」個です。そこに好きな順番でA[next_a]とB[next_b]を入れるのでこうなります。
最後は、C[next_c]の選び方の通り数です。C_tori[a_next][b_next]は「next_a < N_EATED_A[x] かつnext_b < N_EATED_B[x]」を満たすxの数です。このうち、「今までにA,B,Cが吐き出した総数」個は既に吐き出されているので、Cが新たに吐き出せるのは(C_tori[a_next][b_next]- tabeta*3)個です。

a_next,b_nextが決まっているときの遷移は、Cを適当に決める=>次の状態に移るという動作をし、これがCの構成の一つのパターンです。
これで遷移を書くと↓のようになります。(言語での説明ができなかった人の顔)

mint res = 0;
int32_t tabeta = eat_all - taberu;
for (int32_t a_next = a-1; a_next >= 0; --a_next)
for (int32_t b_next = b-1; b_next >= 0; --b_next)
{
	//↓つまり、「AがA[a_next],BがB[b_next]を食べる」というのが可能ならば
	if (b_next &lt; N_EATED_B[A[a_next]] && a_next &lt; N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
		//Cを適当に決める=>次の状態に移るという動作をするので
		//通り数を数えるDPと同じようにしてやった後、Cの選び方をかける
		res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
	}
}

↓考察ノート(本番に実際使ったもの)
f:id:eiya5498513:20171023000702j:plain

↓この段階でのコード(メモ化していないので、O(N^5)よりもさらに遅いです)
本番で使ったコードのバグを少し直しただけなので、「吐き出す」が「食べる」になっています。
mintの定義は省略してありますが、要はmodを勝手にとるintです。詳しくはこちら:C++14用mod_int - 永夜の記録

int32_t eat_all;
int32_t N_EATED_A[400];
int32_t N_EATED_B[400];
int32_t C_tori[400][400];//a,b
int32_t N;
int32_t A[400];
int32_t B[400];
mint func(int a, int b, int taberu)
{
	if (a == 0 || b == 0) {
		if (a == 0 && b == 0 && taberu == 0) {
			return 1_mi;
		}
		else {
			return 0_mi;
		}
	}
	if (taberu == 0) {
		assert(false);
		return 0_mi;
	}

	mint res = 0;
	int32_t tabeta = eat_all - taberu;
	for (int32_t a_next = a-1; a_next >= 0; --a_next)
	for (int32_t b_next = b-1; b_next >= 0; --b_next)
	{
		if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
			res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
		}
	}
	return func3(a, b, taberu);
}

int main()
{
	using std::endl;
	in.sync_with_stdio(false);
	out.sync_with_stdio(false);
	in.tie(nullptr);
	out.tie(nullptr);

	in >> N; eat_all = N / 3;
	for (int32_t i = 0; i < N; i++)
	{
		in >> A[i]; --A[i];
		N_EATED_A[A[i]] = i;
	}
	for (int32_t i = 0; i < N; i++)
	{
		in >> B[i]; --B[i];
		N_EATED_B[B[i]] = i;
	}


	for (int32_t a = N - 1; a >= 0; --a)
		for (int32_t b = N - 1; b >= 0; --b)
		{
			if (b == N - 1) {
				C_tori[a][b] = 0;
			}
			else {
				C_tori[a][b] = C_tori[a][b + 1];
				//B[b+1]が使えるようになる
				if (a < N_EATED_A[B[b + 1]]) {
					++C_tori[a][b];
				}
			}
			//for (int32_t i = 0; i < N; i++)
			//{
			//	if(a<N_EATED_A[i] && b<N_EATED_B[i])
			//		++C_tori[a][b];
			//}
		}

//Nよりも順位が良い全てのもの、つまり全ての寿司を吐き出せます

	out << func(N, N, eat_all) << endl;

	return 0;
}

オーダーを減らす

この段階で、O(N^5)の解法が出来たので、オーダーを減らします。
先ほども書きましたが、

for (int32_t a_next = a-1; a_next >= 0; --a_next)
for (int32_t b_next = b-1; b_next >= 0; --b_next)
{
	if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
		res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
	}
}

が累積和っぽいテクを使うことで減らせます。(ここは発想ではなく知識です)
ここでeiya流雑なDPを使ってまずbのループを消していきます。

for (int32_t a_next = a-1; a_next >= 0; --a_next)
	res += func2(a_next,b,taberu);

になるようにfunc2を書きます。
とりあえずコピペしてメモ化します。

bool DP2_u[400][400][400];
mint DP2[400][400][400];
mint func2(int32_t a_next, int32_t b, int32_t taberu)
{
	if (DP2_u[a_next][b][taberu]) {
		return DP2[a_next][b][taberu];
	}
	DP2_u[a_next][b][taberu] = true;

	mint res = 0;
	int32_t tabeta = eat_all - taberu;
 
	for (int32_t b_next = b-1; b_next >= 0; --b_next)
	{
		if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
			res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
		}
	}
	return DP2[a_next][b][taberu] = res;
}

func2は[0,b)の

if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
	res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
}

の和なので、こうします。

bool DP2_u[400][400][400];
mint DP2[400][400][400];
mint func2(int32_t a_next, int32_t b, int32_t taberu)
{
	if(b==0){return 0_mi;}//0_miはmint(0)と同じ意味です。
	if (DP2_u[a_next][b][taberu]) {
		return DP2[a_next][b][taberu];
	}
	DP2_u[a_next][b][taberu] = true;

	mint res = 0;
	int32_t tabeta = eat_all - taberu;
 	
 	//b_next==b-1の処理
	int b_next = b-1;
 	if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) {
		res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3);
	}
	//b_nextが[0,b-1)の処理
	res += func2(a_next, b-1, taberu);
	return DP2[a_next][b][taberu] = res;
}

これでfunc2の一回当たりの実行時間はO(1)になりました。

なのでfuncの

for (int32_t a_next = a-1; a_next >= 0; --a_next)
	res += func2(a_next,b,taberu);

はこの時点でO(N)です。

同じようにfuncのaのループも消すとこうなります。
実際に書いたコードはこっち:
Submission #1705609 - CODE FESTIVAL 2017 qual C

int32_t eat_all;
int32_t N_EATED_A[400];
int32_t N_EATED_B[400];
int32_t C_tori[400][400];//a,b
int32_t N;
int32_t A[400];
int32_t B[400];
mint func(int a, int b, int taberu);

bool DP2_u[400][400][400];
mint DP2[400][400][400];
mint func2(int32_t a_next, int32_t b, int32_t taberu)
{
	if (b == 0) { return 0_mi; }//0_miはmint(0)と同じ意味です。
	if (DP2_u[a_next][b][taberu]) {
		return DP2[a_next][b][taberu];
	}
	DP2_u[a_next][b][taberu] = true;

	mint res = 0;
	int32_t tabeta = eat_all - taberu;

	//b_next==b-1の処理
	int b_next = b - 1;
	if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta * 3)) {
		res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next] - tabeta * 3);
	}
	//b_nextが[0,b-1)の処理
	res += func2(a_next, b - 1, taberu);
	return DP2[a_next][b][taberu] = res;
}

bool DP3_u[400][400][400];
mint DP3[400][400][400];
mint func(int a, int b, int taberu)
{
	if (a == 0 || b == 0) {
		if (a == 0 && b == 0 && taberu == 0) {
			return 1_mi;
		}
		else {
			return 0_mi;
		}
	}
	if (taberu == 0) {
		return 0_mi;
	}

	if (DP3_u[a][b][taberu]) {
		return DP3[a][b][taberu];
	}
	DP3_u[a][b][taberu] = true;

	auto a_next = a - 1;
	DP3[a][b][taberu] += func2(a_next, b, taberu);
	DP3[a][b][taberu] += func(a - 1, b, taberu);
	return DP3[a][b][taberu];
}

int main() {
//先ほどと同じなので省略
}

この時点でO(N^3)で、MLEです。

メモリ節約

DPのメモリを節約する一般的なテク、DP配列の使いまわしをします。(ここも知識)

DPが二つに分かれているので分かりにくいですが、再帰の起点であるfuncから処理を追っていくと、
taberu==tのときの更新の際、直接呼び出しているのは(つまり、漸化式に入っているのは)

  • a,bがより低くtaberu==tのfunc又はfunc2
  • taberu==t-1のfunc又はfunc2

のみであることが分かります。
よって、taberuを0~Nまで増やしていく方向にforDPをすると、taberuを節約できることが分かります。
この方針で再帰をforDPに書き直すとMLEを解決出来て、ACできます。

実際に書いたコード:Submission #1705841 - CODE FESTIVAL 2017 qual C
解説用に書き直したコード:Submission #1707527 - CODE FESTIVAL 2017 qual C

解いた経緯


(注:9時から19時まで模試でした)