【comptime DP】連鎖行列積問題をコンパイル時の動的計画法で解く【Zig言語】

はじめに

こんにちは。GMO NIKKOのshunkiです。

Zigというプログラミング言語があります。 速いと噂のJavaScriptランタイムBunの開発言語として採用されているようで、にわかに盛り上がったみたいです。 気になったので、私も触ってみました。

この記事はZigの異様に強いコンパイル時計算を悪用し、連鎖行列積問題をコンパイル時に解きます。 何かの参考になるような話ではありませんが、コンパイル時計算の強力さを感じられると思います。 最後までお付き合いいただければ幸いです。

連鎖行列積問題:効率よく計算する順序を求める問題

まずは連鎖行列積問題について簡単に説明しておきましょう。 ただ、コンパイル時に計算することが主題なので、詳細に立ち入った説明は避けておきます。 また、すでに知っている人は読み飛ばしてください。

「連鎖行列積問題」に助詞を加えるとすれば、「連鎖した行列の積の問題」でしょうか。 つまり次の4つがわかれば、連鎖行列積問題を理解できます。

  1. 行列とは何か
  2. 行列の積とは何か
  3. 行列積が連鎖するとはどういうことか
  4. 連鎖行列積の何が問題となるのか

行列とは何か

二次元配列だと思ってください。 つまり同じ型のデータが行と列に並んで集まったデータをひっくるめて行列と呼ぶ感じです。

行列の中に並んだ個々のデータを要素と言います。 高校数学で学ぶ場合は実数の要素を想定することが多いですね。 実数が要素の行列を、特に実行列と呼んだりします。 複素数の要素からなる行列であれば、複素行列です。

行列の正確な定義は数学が得意な人にお譲りしておきましょう。 細かい話をすれば、この記事では、引き算、割り算を使わないため、半環の要素であればヨシという立場を取ります。

例えば、この記事では縦と横に数値を並べたモノを行列と呼ぶことにします。 また、行列の行数rと列数cを用いてr×c行列と言います。

行列の積とは何か

行列の積は、そのまま2つの行列の掛け算を意味しています。 詳細は省きますが、r0×c0行列とr1×c1行列の掛け算ができるのは、c0とr1が一致するときだけです。 そして結果はr0×c1行列になります。 また、行列の掛け算を計算する場合、要素の掛け算がr0×c0×c1(r0×r1×c1でも同じ)回必要です。

例えば、4×3行列の列数と3×2行列の行数は共に3で一致しているため、行列積が計算できます。 また、結果は4×2行列になり、要素の掛け算は4×3×2=24回必要になります。

念のため確認しておくと、左右逆に並べて3×2型と4×3型の行列として行列積を計算しようとしてもできません。 なぜなら、列数2と行数4が一致しないからです。

行列積が連鎖するとはどういうことか

連鎖とは単純に行列を何個も並べて掛け算することを意味します。

ところで、単に並べるといっても掛け算の順序はどうなるのでしょうか。 例えば、行列A、B、Cがあるとき、A×B×Cは(A×B)×Cと理解すべきでしょうか、A×(B×C)と理解すべきでしょうか。 つまり、A×BとB×Cのどちらを先に計算すればいいのか定かではありません。

答えは「どっちでもいい」です。 行列の掛け算は、どこから計算しても計算結果が変わりません(結合律が成り立つ)。

連鎖行列積の何が問題となるのか

どこから行列の積を求めても計算結果が変わらないのであれば、できるだけ都合のいい順序で計算したいと思うのが人情です。 具体的には要素の掛け算の回数を最小化したいという動機を元に、連鎖行列積問題は考察されます。

たとえば、4×3行列A、3×3行列B、3×2行列Cがあったとしましょう。 このとき、計算結果は(A×B)×Cと計算しても、A×(B×C)と計算しても変わりません。

しかし、要素の掛け算の回数は異なります。 (A×B)×Cの場合、A×Bの計算時に4×3×3=36回、(A×B)×Cのときに4×3×2=24回、合計60回の掛け算が必要です。 一方、A×(B×C)の場合、B×Cの計算時に3×3×2=18回、A×(B×C)のときに4×3×2=24回、合計42回の掛け算で済みます。 つまり、この場合はA×(B×C)と計算したほうがお得なわけですね。

連鎖行列積問題を効率よく解くには

連鎖行列積問題は言い換えると、連鎖行列積に要素の掛け算の回数が最小になるように括弧を正しく並べる問題です。 括弧を正しく並べる通り数はカタラン数として知られており、爆発的(指数関数)に増えることが知られています。 そのため、行列積の連鎖が少し伸びただけで、あり得る組み合わせが爆発的に増えてしまいます。

こう言うと、連鎖行列積問題は現実的な時間で解けないのかと不安になるかもしれません。 でも安心してください。 動的計画法を使って効率的(三次多項式)に連鎖行列積問題を解く方法があります。 この記事では、この動的計画法のアルゴリズムをコンパイル時に実行します。

なお、さらに効率のいいアルゴリズムも知られていますが、この記事では使いません。

行列の実装:コンパイル時に行列積が可能か確認する

連鎖行列積問題を確認したので、あとは実装していくだけですね。 ところが、想定以上にコードが長くなってしまいました。 完全なコードは記事の最後に載せておくので、ここでは簡単な説明に留めておきます。

行列の実装が気になる人はMut関数を見てみてください。 なんで関数なんだと思うかもしれませんが、Zigはコンパイル時に型を値のように扱えます。 そして当然のように型を引数にとって、型を返す関数が書けます。 ついでに、コンパイル時定数も型を生成する引数として扱えます。

つまり、行数と列数の情報を型に含めるなら、型関数を書けばいいわけです。 この記事の実装ではMat(u8, 4, 3)で、要素の型がu8の4×3行列の型が生成できます。

また、型のパラメータを取り出せるように定義しています。 例えば、Mat(u8, 4, 3).Itemと書くと型u8が得られ、 Mat(u8, 4, 3).rdimと書くと定数4が、 Mat(u8, 4, 3).cdimで定数3が得られます。 これで、列数と行数の一致が確認できます。 そして、戻り値の行列の行数、列数も引数の型から導出できます。

つまり、行列積の安全性をコンパイル時に保障できます。

素朴な連鎖行列積の実装:先頭から順番に掛けていく

速度比較のために、先頭から順番に行列を掛けていくnaiveMMC関数を用意しました。 naiveMMCの本体は_naiveMMCで、単純な再帰関数になっています。

なぜループを使わないのか疑問でしょうか。 しかし、考えてみるとループの中でaccumulator = accumulator.mul(args[i])と書こうにも、 行列積を計算するたびに累積変数accumulatorの行数と列数が変わってしまう可能性があります。 これではaccumulatorにどんな型(Mut(u8, ?, ?))をつければいいかわかりません。

もちろん、再帰関数でも、戻り値の型は変わってしまいます。 ですが、再帰するたびに変数が定義されるため、関数内の変数には個別に型が付きます。 つまり、型を固定する必要がないので、コンパイルが通るわけです。

動的計画法による連鎖行列積の実装

この記事の主題である、コンパイル時に動的計画法の連鎖行列積を解くのはMMC関数です。 MMCで動的計画法を実行し、その結果をもとに_MMCで行列積を計算します。

MMCcomptime blk: {...}の部分がコンパイル時に動的計画法を実行するコードです。 ほとんど通常の動的計画法のコードと変わりません。 違いは、@setEvalBranchQuotaを指定している所ぐらいでしょうか。 @setEvalBranchQuotaを指定しているのは、長めの連鎖を解く時にコンパイラが諦めないようにするためです。

_MMCは動的計画法で構成された最適な順序を辿って、行列積を計算します。 やや面倒な条件分岐をしていますが、中間的に生成される行列のメモリ管理を行うためのもので、あまり気にする必要はありません。

実行時間を素朴な実装と動的計画法の実装で簡単に比べてみる

naiveMMCMMCに100×90行列、90×80行列、…、20×10行列の連鎖行列積を計算させ、その実行時間を比べてみます。

最もコストが高い計算順序はnaiveMMCの先頭から計算する方法で、2,400,000回です。 一方、最もコストが低い計算順序は末尾から計算する方法で328,000回です。 理論どおりであれば、naiveMMCの約7倍MMCが速いはずです。

10回ほど動かしてみると、naiveMMCの約8倍MMCが速かったです。 マイクロベンチマークにさほど意味はないでしょうが、ほとんど理論値どおりに早くなりました。 やったー。


完全なコード