F - Three Gluttons(CODE FESTIVAL 2017 qual C)
この問題です。
F - Three Gluttons
1800点問題自力AC!!!!!!!!!!!!!!
— eiya@受験競プロC++@DDCC (@eiya5498513) 2017年10月22日
しかも所要時間は別の問題の考察してる時間も含めて2時間半!https://t.co/HMyqIqeNNM
というわけで解説と考察の道筋を書きます。
実際にした考察
「AとBの残っている寿司」と「あと何個食べるか」を引数にDPが出来そう
- 「状態」が2^Nになるなぁ...。状態をNで表せるとO(N^3)で解けそう。
- でも表せたとしてDPの遷移を思いつかないな...
...色々試します...
始めは必ず取らないといけないとか、最後二つは絶対に取れないとか考えましたが、そこから派生出来ない...
(真隣に依存する操作が無い為、端が決まっても嬉しくないなぁ...)
そのうち「逆順をやってみるか」と思います。
逆順で状態をNで表せるように頑張ります。
...そもそもこの問題の場合の逆順ってなんだ???(しばらく混乱)
(ちなみに考察中は逆順に「食べる」という用語で考察したんですが、書いてみたら紛らわしかったので「吐き出す」と書いておきます)
- けんかしないで最後まで取る=>0個のこる
- そこから一個取る前の状態に戻す
- ああ、後ろ二つは絶対に取れないんだっけ。後ろから三つ目に吐き出すか?
- なんで後ろから三つ目でないといけないんだ??<=ここ詰めよう
- DPみたいにしたいので、今吐き出したものの順位をa,b(0<=a,b<N)と置くか
Aさんが一個目にxを吐き出すとき、他の人はxを食べられないので、Bさんがxをb以上の順位にしていると困るな...
(補足:Bさんがxをbよりも上の順位にしている場合、xもB[b]も残っている状況ではBさんはxを取りたいのでけんかが起こります)
=>逆に言えば、Aさんが吐き出せるのはBさんにとってbよりも下の順位のものだけだな???
ということは...
Aにとってaより上の順位のもの:必ず吐き出すことが可能、下の順位のもの:後からは絶対に吐き出せない
(補足:a<jについてAさんはA[j]よりもA[a]を先に食べる<=>A[j]よりもA[a]を後に吐き出す)
Bも同じ
=>これなら「AとBの残っている(吐き出せる寿司)寿司」をa,bという境界を持ってやることでN^2通りで表せる!
=>よしO(N^3)DP出来た!!!
...これは罠で、DPの状態(引数)は決めることが出来ましたが、まだ遷移が出来ていません。
考えます。
吐き出せるやつを全部試すとN^2かかって間に合わないし...
...思いつかないですね。
仕方がないので「吐き出せるやつを全部試す」の方針をもう一度考えます。
...そういえばこれって、前の方のやつを全部足す形だな...
=>累積和が使えるパターンでは!??(ここは一般的なテクとして知識を持っておかないと辛い)
こういう問題の場合、僕はコードを実際に書いて、それを無理やり整形して解くことが多いので、出来そうという方針が立ったこの段階でコードを書き始めます。
DPの説明
後半にコードが書いてあるので、そちらと見比べながら見てください。
DPの定義式を考えます。
本番中僕は言語化せずに感覚的にやってしまったのですが、頑張って言語化します。
DP[a][b][taberu]を
- 「Aはa、Bはbを直前に吐き出した」つまり「Aはaより好きなものを、Bはbより好きなものを吐き出すことが出来る」
- 「A,B,Cはあとtaberu個吐き出す」
の条件を満たすようにCを作ったとき、作れるCの通り数
とします。
A,Bの先頭を吐き出すまでCの構成を続けるので、完成した場合であるa=0,b=0のときCの通り数は1です。
つまり
DP[0][0][0] = 1
です。
例えば、
入力例 3
6 1 2 3 4 5 6 2 1 4 3 6 5
の時は
DP[2][2][1]=20
みたいになります。
a,b,は0-indexed、taberuは個数(0個1個2個...N個)であることに注意してください。
(計算方法は後述)
「次にA,Bが吐き出すものの順位」をnext_a,next_bとします。(これは二重ループで決め打ちします)
A[i],B[i]を問題文通りに、i番目に好きな寿司とします。(コードでは0-indexed)
Aがxが何番目に好きかをN_EATED_A[x]、Bがxが何番目に好きかをN_EATED_B[x]とします。(変数の命名は突っ込まない約束です)
AがA[next_a]を吐き出す為には、BがB[next_b]よりもA[next_a]が嫌いである、つまりnext_b < N_EATED_B[A[next_a]]である必要があります。
同様にBがB[next_b]を吐き出す為にはnext_a < N_EATED_A[B[next_b]]である必要があります。
これを満たしているとき、AはA[next_a]、BはB[next_b]を吐き出します。
この時のCの通り数を計算します。
Cは自由に決められるので、吐き出すごとに生まれる制約を使ってCを少しずつ決めてやるイメージでやります。
Cが吐き出すものをA,Bと同じようにC[next_c]と仮置きします。あとN_EATED_C[x]も同じように仮置きします。
まず、A,Bの時と同じ理由で、next_c < N_EATED_C[A[next_a]] かつ next_c < N_EATED_C[B[next_b]]である必要が有ります。
なので、next_cの前にA[next_a]とB[next_b]を好きな順番で入れます。
また、これまたA,Bのときと同じ理由で、next_a < N_EATED_A[C[next_c]] かつnext_b < N_EATED_B[C[next_c]]である必要が有ります。
これを満たす為に、C[next_c]は上の条件を満たす寿司から一つ選んで吐き出します。
この通り数が、コードの
*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3)
に相当します。
tabetaは「今までに吐き出した数」です。(コードは本番のものなので、「吐き出す」が「食べる」になっています)
始め二つが、「A[next_a]とB[next_b]を好きな順番で入れる」の通り数です。Cの中で順番が確定しているのは「今までにA,B,Cが吐き出した総数」個です。そこに好きな順番でA[next_a]とB[next_b]を入れるのでこうなります。
最後は、C[next_c]の選び方の通り数です。C_tori[a_next][b_next]は「next_a < N_EATED_A[x] かつnext_b < N_EATED_B[x]」を満たすxの数です。このうち、「今までにA,B,Cが吐き出した総数」個は既に吐き出されているので、Cが新たに吐き出せるのは(C_tori[a_next][b_next]- tabeta*3)個です。
a_next,b_nextが決まっているときの遷移は、Cを適当に決める=>次の状態に移るという動作をし、これがCの構成の一つのパターンです。
これで遷移を書くと↓のようになります。(言語での説明ができなかった人の顔)
mint res = 0; int32_t tabeta = eat_all - taberu; for (int32_t a_next = a-1; a_next >= 0; --a_next) for (int32_t b_next = b-1; b_next >= 0; --b_next) { //↓つまり、「AがA[a_next],BがB[b_next]を食べる」というのが可能ならば if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { //Cを適当に決める=>次の状態に移るという動作をするので //通り数を数えるDPと同じようにしてやった後、Cの選び方をかける res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); } }
↓考察ノート(本番に実際使ったもの)
↓この段階でのコード(メモ化していないので、O(N^5)よりもさらに遅いです)
本番で使ったコードのバグを少し直しただけなので、「吐き出す」が「食べる」になっています。
mintの定義は省略してありますが、要はmodを勝手にとるintです。詳しくはこちら:C++14用mod_int - 永夜の記録
int32_t eat_all; int32_t N_EATED_A[400]; int32_t N_EATED_B[400]; int32_t C_tori[400][400];//a,b int32_t N; int32_t A[400]; int32_t B[400]; mint func(int a, int b, int taberu) { if (a == 0 || b == 0) { if (a == 0 && b == 0 && taberu == 0) { return 1_mi; } else { return 0_mi; } } if (taberu == 0) { assert(false); return 0_mi; } mint res = 0; int32_t tabeta = eat_all - taberu; for (int32_t a_next = a-1; a_next >= 0; --a_next) for (int32_t b_next = b-1; b_next >= 0; --b_next) { if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); } } return func3(a, b, taberu); } int main() { using std::endl; in.sync_with_stdio(false); out.sync_with_stdio(false); in.tie(nullptr); out.tie(nullptr); in >> N; eat_all = N / 3; for (int32_t i = 0; i < N; i++) { in >> A[i]; --A[i]; N_EATED_A[A[i]] = i; } for (int32_t i = 0; i < N; i++) { in >> B[i]; --B[i]; N_EATED_B[B[i]] = i; } for (int32_t a = N - 1; a >= 0; --a) for (int32_t b = N - 1; b >= 0; --b) { if (b == N - 1) { C_tori[a][b] = 0; } else { C_tori[a][b] = C_tori[a][b + 1]; //B[b+1]が使えるようになる if (a < N_EATED_A[B[b + 1]]) { ++C_tori[a][b]; } } //for (int32_t i = 0; i < N; i++) //{ // if(a<N_EATED_A[i] && b<N_EATED_B[i]) // ++C_tori[a][b]; //} } //Nよりも順位が良い全てのもの、つまり全ての寿司を吐き出せます out << func(N, N, eat_all) << endl; return 0; }
オーダーを減らす
この段階で、O(N^5)の解法が出来たので、オーダーを減らします。
先ほども書きましたが、
for (int32_t a_next = a-1; a_next >= 0; --a_next) for (int32_t b_next = b-1; b_next >= 0; --b_next) { if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); } }
が累積和っぽいテクを使うことで減らせます。(ここは発想ではなく知識です)
ここでeiya流雑なDPを使ってまずbのループを消していきます。
for (int32_t a_next = a-1; a_next >= 0; --a_next) res += func2(a_next,b,taberu);
になるようにfunc2を書きます。
とりあえずコピペしてメモ化します。
bool DP2_u[400][400][400]; mint DP2[400][400][400]; mint func2(int32_t a_next, int32_t b, int32_t taberu) { if (DP2_u[a_next][b][taberu]) { return DP2[a_next][b][taberu]; } DP2_u[a_next][b][taberu] = true; mint res = 0; int32_t tabeta = eat_all - taberu; for (int32_t b_next = b-1; b_next >= 0; --b_next) { if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); } } return DP2[a_next][b][taberu] = res; }
func2は[0,b)の
if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); }
の和なので、こうします。
bool DP2_u[400][400][400]; mint DP2[400][400][400]; mint func2(int32_t a_next, int32_t b, int32_t taberu) { if(b==0){return 0_mi;}//0_miはmint(0)と同じ意味です。 if (DP2_u[a_next][b][taberu]) { return DP2[a_next][b][taberu]; } DP2_u[a_next][b][taberu] = true; mint res = 0; int32_t tabeta = eat_all - taberu; //b_next==b-1の処理 int b_next = b-1; if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta*3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next]- tabeta*3); } //b_nextが[0,b-1)の処理 res += func2(a_next, b-1, taberu); return DP2[a_next][b][taberu] = res; }
これでfunc2の一回当たりの実行時間はO(1)になりました。
なのでfuncの
for (int32_t a_next = a-1; a_next >= 0; --a_next) res += func2(a_next,b,taberu);
はこの時点でO(N)です。
同じようにfuncのaのループも消すとこうなります。
実際に書いたコードはこっち:
Submission #1705609 - CODE FESTIVAL 2017 qual C
int32_t eat_all; int32_t N_EATED_A[400]; int32_t N_EATED_B[400]; int32_t C_tori[400][400];//a,b int32_t N; int32_t A[400]; int32_t B[400]; mint func(int a, int b, int taberu); bool DP2_u[400][400][400]; mint DP2[400][400][400]; mint func2(int32_t a_next, int32_t b, int32_t taberu) { if (b == 0) { return 0_mi; }//0_miはmint(0)と同じ意味です。 if (DP2_u[a_next][b][taberu]) { return DP2[a_next][b][taberu]; } DP2_u[a_next][b][taberu] = true; mint res = 0; int32_t tabeta = eat_all - taberu; //b_next==b-1の処理 int b_next = b - 1; if (b_next < N_EATED_B[A[a_next]] && a_next < N_EATED_A[B[b_next]] && (C_tori[a_next][b_next] > tabeta * 3)) { res += func(a_next, b_next, taberu - 1)*(3 * tabeta + 1)*(3 * tabeta + 2)*(C_tori[a_next][b_next] - tabeta * 3); } //b_nextが[0,b-1)の処理 res += func2(a_next, b - 1, taberu); return DP2[a_next][b][taberu] = res; } bool DP3_u[400][400][400]; mint DP3[400][400][400]; mint func(int a, int b, int taberu) { if (a == 0 || b == 0) { if (a == 0 && b == 0 && taberu == 0) { return 1_mi; } else { return 0_mi; } } if (taberu == 0) { return 0_mi; } if (DP3_u[a][b][taberu]) { return DP3[a][b][taberu]; } DP3_u[a][b][taberu] = true; auto a_next = a - 1; DP3[a][b][taberu] += func2(a_next, b, taberu); DP3[a][b][taberu] += func(a - 1, b, taberu); return DP3[a][b][taberu]; } int main() { //先ほどと同じなので省略 }
この時点でO(N^3)で、MLEです。
メモリ節約
DPのメモリを節約する一般的なテク、DP配列の使いまわしをします。(ここも知識)
DPが二つに分かれているので分かりにくいですが、再帰の起点であるfuncから処理を追っていくと、
taberu==tのときの更新の際、直接呼び出しているのは(つまり、漸化式に入っているのは)
- a,bがより低くtaberu==tのfunc又はfunc2
- taberu==t-1のfunc又はfunc2
のみであることが分かります。
よって、taberuを0~Nまで増やしていく方向にforDPをすると、taberuを節約できることが分かります。
この方針で再帰をforDPに書き直すとMLEを解決出来て、ACできます。
実際に書いたコード:Submission #1705841 - CODE FESTIVAL 2017 qual C
解説用に書き直したコード:Submission #1707527 - CODE FESTIVAL 2017 qual C
解いた経緯
あの。これあとこだ出たら何時間頭を酷使する計算になるんですか。(出るつもりだったけど、もう手に力があまり入らない為険しい)
— eiya@受験競プロC++@DDCC (@eiya5498513) 2017年10月22日
(注:9時から19時まで模試でした)
僕のこどふぇす
— eiya@受験競プロC++@DDCC (@eiya5498513) 2017年10月22日
出ます(no submit) https://t.co/8ZuD5qGHJ2
— eiya@受験競プロC++@DDCC (@eiya5498513) 2017年10月22日