はじめに
こんにちは。GMO NIKKOのshunkiです。
Zigというプログラミング言語があります。 速いと噂のJavaScriptランタイムBunの開発言語として採用されているようで、にわかに盛り上がったみたいです。 気になったので、私も触ってみました。
この記事はZigの異様に強いコンパイル時計算を悪用し、連鎖行列積問題をコンパイル時に解きます。 何かの参考になるような話ではありませんが、コンパイル時計算の強力さを感じられると思います。 最後までお付き合いいただければ幸いです。
連鎖行列積問題:効率よく計算する順序を求める問題
まずは連鎖行列積問題について簡単に説明しておきましょう。 ただ、コンパイル時に計算することが主題なので、詳細に立ち入った説明は避けておきます。 また、すでに知っている人は読み飛ばしてください。
「連鎖行列積問題」に助詞を加えるとすれば、「連鎖した行列の積の問題」でしょうか。 つまり次の4つがわかれば、連鎖行列積問題を理解できます。
- 行列とは何か
- 行列の積とは何か
- 行列積が連鎖するとはどういうことか
- 連鎖行列積の何が問題となるのか
行列とは何か
二次元配列だと思ってください。 つまり同じ型のデータが行と列に並んで集まったデータをひっくるめて行列と呼ぶ感じです。
行列の中に並んだ個々のデータを要素と言います。 高校数学で学ぶ場合は実数の要素を想定することが多いですね。 実数が要素の行列を、特に実行列と呼んだりします。 複素数の要素からなる行列であれば、複素行列です。
行列の正確な定義は数学が得意な人にお譲りしておきましょう。 細かい話をすれば、この記事では、引き算、割り算を使わないため、半環の要素であればヨシという立場を取ります。
例えば、この記事では縦と横に数値を並べたモノを行列と呼ぶことにします。 また、行列の行数rと列数cを用いてr×c行列と言います。
行列の積とは何か
行列の積は、そのまま2つの行列の掛け算を意味しています。 詳細は省きますが、r0×c0行列とr1×c1行列の掛け算ができるのは、c0とr1が一致するときだけです。 そして結果はr0×c1行列になります。 また、行列の掛け算を計算する場合、要素の掛け算がr0×c0×c1(r0×r1×c1でも同じ)回必要です。
例えば、4×3行列の列数と3×2行列の行数は共に3で一致しているため、行列積が計算できます。 また、結果は4×2行列になり、要素の掛け算は4×3×2=24回必要になります。
念のため確認しておくと、左右逆に並べて3×2型と4×3型の行列として行列積を計算しようとしてもできません。 なぜなら、列数2と行数4が一致しないからです。
行列積が連鎖するとはどういうことか
連鎖とは単純に行列を何個も並べて掛け算することを意味します。
ところで、単に並べるといっても掛け算の順序はどうなるのでしょうか。 例えば、行列A、B、Cがあるとき、A×B×Cは(A×B)×Cと理解すべきでしょうか、A×(B×C)と理解すべきでしょうか。 つまり、A×BとB×Cのどちらを先に計算すればいいのか定かではありません。
答えは「どっちでもいい」です。 行列の掛け算は、どこから計算しても計算結果が変わりません(結合律が成り立つ)。
連鎖行列積の何が問題となるのか
どこから行列の積を求めても計算結果が変わらないのであれば、できるだけ都合のいい順序で計算したいと思うのが人情です。 具体的には要素の掛け算の回数を最小化したいという動機を元に、連鎖行列積問題は考察されます。
たとえば、4×3行列A、3×3行列B、3×2行列Cがあったとしましょう。 このとき、計算結果は(A×B)×Cと計算しても、A×(B×C)と計算しても変わりません。
しかし、要素の掛け算の回数は異なります。 (A×B)×Cの場合、A×Bの計算時に4×3×3=36回、(A×B)×Cのときに4×3×2=24回、合計60回の掛け算が必要です。 一方、A×(B×C)の場合、B×Cの計算時に3×3×2=18回、A×(B×C)のときに4×3×2=24回、合計42回の掛け算で済みます。 つまり、この場合はA×(B×C)と計算したほうがお得なわけですね。
連鎖行列積問題を効率よく解くには
連鎖行列積問題は言い換えると、連鎖行列積に要素の掛け算の回数が最小になるように括弧を正しく並べる問題です。 括弧を正しく並べる通り数はカタラン数として知られており、爆発的(指数関数)に増えることが知られています。 そのため、行列積の連鎖が少し伸びただけで、あり得る組み合わせが爆発的に増えてしまいます。
こう言うと、連鎖行列積問題は現実的な時間で解けないのかと不安になるかもしれません。 でも安心してください。 動的計画法を使って効率的(三次多項式)に連鎖行列積問題を解く方法があります。 この記事では、この動的計画法のアルゴリズムをコンパイル時に実行します。
なお、さらに効率のいいアルゴリズムも知られていますが、この記事では使いません。
行列の実装:コンパイル時に行列積が可能か確認する
連鎖行列積問題を確認したので、あとは実装していくだけですね。 ところが、想定以上にコードが長くなってしまいました。 完全なコードは記事の最後に載せておくので、ここでは簡単な説明に留めておきます。
行列の実装が気になる人はMut
関数を見てみてください。 なんで関数なんだと思うかもしれませんが、Zigはコンパイル時に型を値のように扱えます。 そして当然のように型を引数にとって、型を返す関数が書けます。 ついでに、コンパイル時定数も型を生成する引数として扱えます。
つまり、行数と列数の情報を型に含めるなら、型関数を書けばいいわけです。 この記事の実装ではMat(u8, 4, 3)
で、要素の型がu8
の4×3行列の型が生成できます。
また、型のパラメータを取り出せるように定義しています。 例えば、Mat(u8, 4, 3).Item
と書くと型u8
が得られ、 Mat(u8, 4, 3).rdim
と書くと定数4が、 Mat(u8, 4, 3).cdim
で定数3が得られます。 これで、列数と行数の一致が確認できます。 そして、戻り値の行列の行数、列数も引数の型から導出できます。
つまり、行列積の安全性をコンパイル時に保障できます。
素朴な連鎖行列積の実装:先頭から順番に掛けていく
速度比較のために、先頭から順番に行列を掛けていくnaiveMMC
関数を用意しました。 naiveMMC
の本体は_naiveMMC
で、単純な再帰関数になっています。
なぜループを使わないのか疑問でしょうか。 しかし、考えてみるとループの中でaccumulator = accumulator.mul(args[i])
と書こうにも、 行列積を計算するたびに累積変数accumulator
の行数と列数が変わってしまう可能性があります。 これではaccumulator
にどんな型(Mut(u8, ?, ?)
)をつければいいかわかりません。
もちろん、再帰関数でも、戻り値の型は変わってしまいます。 ですが、再帰するたびに変数が定義されるため、関数内の変数には個別に型が付きます。 つまり、型を固定する必要がないので、コンパイルが通るわけです。
動的計画法による連鎖行列積の実装
この記事の主題である、コンパイル時に動的計画法の連鎖行列積を解くのはMMC
関数です。 MMC
で動的計画法を実行し、その結果をもとに_MMC
で行列積を計算します。
MMC
のcomptime blk: {...}
の部分がコンパイル時に動的計画法を実行するコードです。 ほとんど通常の動的計画法のコードと変わりません。 違いは、@setEvalBranchQuota
を指定している所ぐらいでしょうか。 @setEvalBranchQuota
を指定しているのは、長めの連鎖を解く時にコンパイラが諦めないようにするためです。
_MMC
は動的計画法で構成された最適な順序を辿って、行列積を計算します。 やや面倒な条件分岐をしていますが、中間的に生成される行列のメモリ管理を行うためのもので、あまり気にする必要はありません。
実行時間を素朴な実装と動的計画法の実装で簡単に比べてみる
naiveMMC
とMMC
に100×90行列、90×80行列、…、20×10行列の連鎖行列積を計算させ、その実行時間を比べてみます。
最もコストが高い計算順序はnaiveMMC
の先頭から計算する方法で、2,400,000回です。 一方、最もコストが低い計算順序は末尾から計算する方法で328,000回です。 理論どおりであれば、naiveMMC
の約7倍MMC
が速いはずです。
1 2 3 4 5 6 7 |
> zig version 0.9.1 > zig build > .\zig-out\bin\matrix_chain_multiplication.exe 10 times naiveMMC: 505075200ns MMC : 62486900ns |
10回ほど動かしてみると、naiveMMC
の約8倍MMC
が速かったです。 マイクロベンチマークにさほど意味はないでしょうが、ほとんど理論値どおりに早くなりました。 やったー。
完全なコード
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
const std = @import("std"); const trait = std.meta.trait; const Allocator = std.mem.Allocator; pub fn main() anyerror!void { var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); defer arena.deinit(); const allocator = arena.allocator(); const m9 = try Mat(u8, 100, 90).init(allocator); defer m9.deinit(); const m8 = try Mat(u8, 90, 80).init(allocator); defer m8.deinit(); const m7 = try Mat(u8, 80, 70).init(allocator); defer m7.deinit(); const m6 = try Mat(u8, 70, 60).init(allocator); defer m6.deinit(); const m5 = try Mat(u8, 60, 50).init(allocator); defer m5.deinit(); const m4 = try Mat(u8, 50, 40).init(allocator); defer m4.deinit(); const m3 = try Mat(u8, 40, 30).init(allocator); defer m3.deinit(); const m2 = try Mat(u8, 30, 20).init(allocator); defer m2.deinit(); const m1 = try Mat(u8, 20, 10).init(allocator); defer m1.deinit(); const N: u32 = 10; var i: u32 = 0; i = 0; const naive_mmc_start = std.time.nanoTimestamp(); while (i < N) : (i += 1) { const naive_m = try naiveMMC(.{ m9, m8, m7, m6, m5, m4, m3, m2, m1 }); naive_m.deinit(); } const naive_mmc_end = std.time.nanoTimestamp(); const naive_mmc_time = naive_mmc_end - naive_mmc_start; i = 0; const mmc_start = std.time.nanoTimestamp(); while (i < N) : (i += 1) { const m = try MMC(.{ m9, m8, m7, m6, m5, m4, m3, m2, m1 }); m.deinit(); } const mmc_end = std.time.nanoTimestamp(); const mmc_time = mmc_end - mmc_start; std.debug.print( \\ {} times \\ naiveMMC: {}ns \\ MMC : {}ns , .{ N, naive_mmc_time, mmc_time }); } pub fn Mat(comptime T: type, comptime rdim: usize, cdim: usize) type { if (!comptime trait.isNumber(T)) @compileError("Expected number type, found '" ++ @typeName(T) ++ "'"); return struct { const Self = @This(); pub const Item = T; pub const rdim = rdim; pub const cdim = cdim; allocator: Allocator, items: [rdim * cdim]Item, pub fn init(allocator: Allocator) Allocator.Error!*Self { const self = try allocator.create(Self); errdefer allocator.destroy(self); self.allocator = allocator; self.items = [_]Item{0} ** (rdim * cdim); return self; } pub fn deinit(self: *Self) void { self.allocator.destroy(self); } pub fn set(self: *Self, row: usize, col: usize, value: Item) void { self.items[cdim * row + col] = value; } pub fn get(self: Self, row: usize, col: usize) Item { return self.items[cdim * row + col]; } pub fn setRow(self: *Self, row: usize, values: []const Item) void { for (values) |v, i| self.set(row, i, v); } pub fn getRow(self: Self, row: usize) []const Item { return self.items[(cdim * row)..(cdim * (row + 1))]; } pub fn mul(self: *const Self, other: anytype) Allocator.Error!*Mat(Item, rdim, @TypeOf(other.*).cdim) { const Other = @TypeOf(other.*); comptime { if (Other != Mat(Item, Other.rdim, Other.cdim)) @compileError("Expected Mat type, found '" ++ @typeName(Other) ++ "'"); if (cdim != Other.rdim) @compileError("Expected cdim == rdim, found '" ++ @typeName(Self) ++ "', '" ++ @typeName(Other) ++ "'"); } var m = try Mat(T, rdim, Other.cdim).init(self.allocator); errdefer m.deinit(); var i: usize = 0; while (i < rdim) : (i += 1) { var k: usize = 0; while (k < cdim) : (k += 1) { var j: usize = 0; while (j < Other.cdim) : (j += 1) { m.set(i, j, m.get(i, j) + self.get(i, k) * other.get(k, j)); } } } return m; } }; } pub fn naiveMMC(args: anytype) Allocator.Error!*Mat(@TypeOf(args[0].*).Item, @TypeOf(args[0].*).rdim, @TypeOf(args[args.len - 1].*).cdim) { if (comptime args.len < 2) @compileError("Expected args.len >= 2, found '" ++ @typeName(@TypeOf(args)) ++ "'"); return try _naiveMMC(@TypeOf(args[0].*).Item, args.len, args); } fn _naiveMMC(comptime T: type, comptime end: usize, args: anytype) Allocator.Error!*Mat(T, @TypeOf(args[0].*).rdim, @TypeOf(args[end - 1].*).cdim) { if (end == 2) return try args[0].mul(args[1]); const m = try _naiveMMC(T, end - 1, args); defer m.deinit(); return try m.mul(args[end - 1]); } pub fn MMC(args: anytype) Allocator.Error!*Mat(@TypeOf(args[0].*).Item, @TypeOf(args[0].*).rdim, @TypeOf(args[args.len - 1].*).cdim) { if (comptime args.len < 2) @compileError("Expected args.len >= 2, found '" ++ @typeName(@TypeOf(args)) ++ "'"); return try _MMC(@TypeOf(args[0].*).Item, 0, args.len - 1, args, comptime blk: { const n = args.len; @setEvalBranchQuota(n * n * n); var descent: [n][n]usize = undefined; var min_cost: [n][n]u64 = undefined; var i = 0; while (i < n) : (i += 1) { descent[i] = [_]usize{std.math.maxInt(usize)} ** n; min_cost[i] = [_]u64{std.math.maxInt(u32)} ** n; min_cost[i][i] = 0; } var len = 1; while (len < n) : (len += 1) { i = 0; while (i < n - len) : (i += 1) { var j = i + len; var k = i; while (k < j) : (k += 1) { const mul_cost = @TypeOf(args[i].*).rdim * @TypeOf(args[k].*).cdim * @TypeOf(args[j].*).cdim; const cost = min_cost[i][k] + min_cost[k + 1][j] + mul_cost; if (cost < min_cost[i][j]) { descent[i][j] = k; min_cost[i][j] = cost; } } } } break :blk &descent; }); } pub fn _MMC(comptime T: type, comptime start: usize, comptime end: usize, args: anytype, comptime descent: *const [args.len][args.len]usize) Allocator.Error!*Mat(T, @TypeOf(args[start].*).rdim, @TypeOf(args[end].*).cdim) { const mid = descent[start][end]; if (start == mid and mid + 1 == end) { return args[start].mul(args[end]); } if (mid == start) { const m = try _MMC(T, mid + 1, end, args, descent); defer m.deinit(); return try args[start].mul(m); } if (mid + 1 == end) { const m = try _MMC(T, start, mid, args, descent); defer m.deinit(); return try m.mul(args[end]); } const m0 = try _MMC(T, start, mid, args, descent); defer m0.deinit(); const m1 = try _MMC(T, mid + 1, end, args, descent); defer m1.deinit(); return try m0.mul(m1); } test { const allocator = std.testing.allocator; var a = try Mat(i32, 4, 3).init(allocator); defer a.deinit(); a.set(0, 0, 1); a.set(0, 1, 0); a.set(0, 2, 1); a.setRow(1, &[_]i32{ 2, 1, 1 }); a.setRow(2, &[_]i32{ 0, 1, 1 }); a.setRow(3, &[_]i32{ 1, 1, 2 }); var b = try Mat(i32, 3, 3).init(allocator); defer b.deinit(); b.setRow(0, &[_]i32{ 1, 2, 1 }); b.setRow(1, &[_]i32{ 2, 3, 1 }); b.setRow(2, &[_]i32{ 4, 2, 2 }); var c = try Mat(i32, 3, 2).init(allocator); defer c.deinit(); c.setRow(0, &[_]i32{ -1, 2 }); c.setRow(1, &[_]i32{ 1, -2 }); c.setRow(2, &[_]i32{ -2, 1 }); { try comptime std.testing.expect(@TypeOf(a.*).rdim == 4); try comptime std.testing.expect(@TypeOf(a.*).cdim == 3); try comptime std.testing.expect(@TypeOf(b.*).rdim == 3); try comptime std.testing.expect(@TypeOf(b.*).cdim == 3); try comptime std.testing.expect(@TypeOf(c.*).rdim == 3); try comptime std.testing.expect(@TypeOf(c.*).cdim == 2); } { try std.testing.expect(a.get(0, 0) == 1); try std.testing.expect(a.get(1, 0) == 2); try std.testing.expect(a.get(2, 1) == 1); try std.testing.expectEqualSlices(i32, b.getRow(0), &[_]i32{ 1, 2, 1 }); try std.testing.expectEqualSlices(i32, b.getRow(1), &[_]i32{ 2, 3, 1 }); try std.testing.expectEqualSlices(i32, b.getRow(2), &[_]i32{ 4, 2, 2 }); } { var x = try a.mul(b); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 4); try comptime std.testing.expect(@TypeOf(x.*).cdim == 3); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ 5, 4, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ 8, 9, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ 6, 5, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(3), &[_]i32{ 11, 9, 6 }); } { var x = try a.mul(b); defer x.deinit(); var y = try x.mul(c); defer y.deinit(); try comptime std.testing.expect(@TypeOf(y.*).rdim == 4); try comptime std.testing.expect(@TypeOf(y.*).cdim == 2); try std.testing.expectEqualSlices(i32, y.getRow(0), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, y.getRow(1), &[_]i32{ -9, 3 }); try std.testing.expectEqualSlices(i32, y.getRow(2), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, y.getRow(3), &[_]i32{ -14, 10 }); } { var x = try naiveMMC(.{ a, b }); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 4); try comptime std.testing.expect(@TypeOf(x.*).cdim == 3); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ 5, 4, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ 8, 9, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ 6, 5, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(3), &[_]i32{ 11, 9, 6 }); } { var x = try naiveMMC(.{ a, b, c }); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 4); try comptime std.testing.expect(@TypeOf(x.*).cdim == 2); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ -9, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(3), &[_]i32{ -14, 10 }); } { var i = try Mat(i32, 3, 3).init(allocator); defer i.deinit(); i.setRow(0, &[_]i32{ 1, 0, 0 }); i.setRow(1, &[_]i32{ 0, 1, 0 }); i.setRow(2, &[_]i32{ 0, 0, 1 }); const len: usize = 50; var ms: [len]*Mat(i32, 3, 3) = undefined; for (ms) |_, k| ms[k] = i; var x = try naiveMMC(ms); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 3); try comptime std.testing.expect(@TypeOf(x.*).cdim == 3); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ 1, 0, 0 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ 0, 1, 0 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ 0, 0, 1 }); } { var x = try MMC(.{ a, b }); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 4); try comptime std.testing.expect(@TypeOf(x.*).cdim == 3); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ 5, 4, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ 8, 9, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ 6, 5, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(3), &[_]i32{ 11, 9, 6 }); } { var x = try MMC(.{ a, b, c }); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 4); try comptime std.testing.expect(@TypeOf(x.*).cdim == 2); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ -9, 3 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ -7, 5 }); try std.testing.expectEqualSlices(i32, x.getRow(3), &[_]i32{ -14, 10 }); } { var i = try Mat(i32, 3, 3).init(allocator); defer i.deinit(); i.setRow(0, &[_]i32{ 1, 0, 0 }); i.setRow(1, &[_]i32{ 0, 1, 0 }); i.setRow(2, &[_]i32{ 0, 0, 1 }); const len: usize = 50; var ms: [len]*Mat(i32, 3, 3) = undefined; for (ms) |_, k| ms[k] = i; var x = try MMC(ms); defer x.deinit(); try comptime std.testing.expect(@TypeOf(x.*).rdim == 3); try comptime std.testing.expect(@TypeOf(x.*).cdim == 3); try std.testing.expectEqualSlices(i32, x.getRow(0), &[_]i32{ 1, 0, 0 }); try std.testing.expectEqualSlices(i32, x.getRow(1), &[_]i32{ 0, 1, 0 }); try std.testing.expectEqualSlices(i32, x.getRow(2), &[_]i32{ 0, 0, 1 }); } } |
