Mmul TB 4_16_8

top Top: -

par Par: 12 lines

Problem Statement

\(4 \times 16\) 行列 \(A\) と、\(8 \times 16\) 行列 \(B\) に対して、行列積 \(C = A \times B^T\) (shape:\(4 \times 8\)) を計算してください。\(A,B,C\) のレイアウトは以下のとおりです。

A: ((4:2), (2:1, 4_PE:1, 2_W:1)) B: ((8:2), (2:1, 4_PE:1, 2_W:1)) C: ((4:1), ( 4_PE:1, 2_W:1))
\(A, B, C\) の値はこちらです。

A:

[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]]

B:

[[100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115], [116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131], [132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147], [148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163], [164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179], [180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195], [196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211], [212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227]]

C:
import numpy as np A = np.array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]]) B = np.array([[100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115], [116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131], [132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147], [148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163], [164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179], [180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195], [196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211], [212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227]]) A @ B.T

[[ 13240, 15160, 17080, 19000, 20920, 22840, 24760, 26680], [ 40760, 46776, 52792, 58808, 64824, 70840, 76856, 82872], [ 68280, 78392, 88504, 98616,108728,118840,128952,139064], [ 95800,110008,124216,138424,152632,166840,181048,195256]]

Explanation

今度は、Mmul TB 4_8_8 から \(A,B\) の 2 次元目が倍になりました。その代わり、出力は同じく \(4\times8\) のままです。

つまり、これまでは長さ 8 の内積だったところ、長さ 16 の内積になりました。行列積命令 gmmul の他に、行列積和命令 gmfma も使用します。

考え方・方針

以下のコードで、8 要素の内積、つまり行列 \(A, B\) の左同士の行列積を計算することができます。

gbfn $lm16v4 $nowrite # 入力の行列 B を $lm8v から変更 gmwrite $aluf $ly0 gbfn $lm32v4 $nowrite # 入力の行列 B を $lm16v から変更 gmwrite $aluf $ly4 gbfn $lm0v4 $nowrite # 入力の行列 A を $lm0v から変更 gmmul $ly $aluf $ln0v

両方の行列の入力サイズが増えたため、\(A\) の領域が $lm[0:8] から $lm[0:16] に、 \(B\) の領域が $lm[8:24] から $lm[16:48] に変わりました。

また、入力の 2 次元目のサイズが倍増したため、1 次元目が (4:2)、つまり 2 長語(4単語)の間隔で 4 個配置されることになります。そのため、$lm16v4 のように表現する必要があります。

自信がなくなったら、書き込まれた行列レジスタの内容などを確認して見ましょう。

gbfn $lm16v4 $nowrite gmwrite $aluf $ly0 gbfn $lm32v4 $nowrite gmwrite $aluf $ly4 d getbf $ly0n0c0b0m0 8
出力:
DEBUG-MRy(n0c0b0m0,0):{(100, 101) (0x42e40000, 0x42e50000), (102, 103) (0x42e60000, 0x42e70000), (104, 105) (0x42e80000, 0x42e90000), (106, 107) (0x42ea0000, 0x42eb0000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,1):{(116, 117) (0x42f40000, 0x42f50000), (118, 119) (0x42f60000, 0x42f70000), (120, 121) (0x42f80000, 0x42f90000), (122, 123) (0x42fa0000, 0x42fb0000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,2):{(132, 133) (0x43420000, 0x43428000), (134, 135) (0x43430000, 0x43438000), (136, 137) (0x43440000, 0x43448000), (138, 139) (0x43450000, 0x43458000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,3):{(148, 149) (0x434a0000, 0x434a8000), (150, 151) (0x434b0000, 0x434b8000), (152, 153) (0x434c0000, 0x434c8000), (154, 155) (0x434d0000, 0x434d8000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,4):{(164, 165) (0x43520000, 0x43528000), (166, 167) (0x43530000, 0x43538000), (168, 169) (0x43540000, 0x43548000), (170, 171) (0x43550000, 0x43558000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,5):{(180, 181) (0x435a0000, 0x435a8000), (182, 183) (0x435b0000, 0x435b8000), (184, 185) (0x435c0000, 0x435c8000), (186, 187) (0x435d0000, 0x435d8000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,6):{(196, 197) (0x43620000, 0x43628000), (198, 199) (0x43630000, 0x43638000), (200, 201) (0x43640000, 0x43648000), (202, 203) (0x43650000, 0x43658000)} #d getbf $ly0n0c0b0m0 8 DEBUG-MRy(n0c0b0m0,7):{(212, 213) (0x436a0000, 0x436a8000), (214, 215) (0x436b0000, 0x436b8000), (216, 217) (0x436c0000, 0x436c8000), (218, 219) (0x436d0000, 0x436d8000)} #d getbf $ly0n0c0b0m0 8

デバッグ出力が長いと思ったら、d getbf $ly0n0c0b0m0 2 のように、2 行だけ表示するだけでも効果的です。

デバッグ Tips:

また、どの内積の組を計算したいのか、というのを考えるのも、デバッグでは重要です。

例えば以下のように書いてみたとします。

gbfn $lm16v $nowrite gmwrite $aluf $ly0 d getbf $ly0n0c0b0m0 2

以下のデバッグ出力が得られます。

DEBUG-MRy(n0c0b0m0,0):{(100, 101) (0x42e40000, 0x42e50000), (102, 103) (0x42e60000, 0x42e70000), (104, 105) (0x42e80000, 0x42e90000), (106, 107) (0x42ea0000, 0x42eb0000)]} #d getbf $ly0n0c0b0m0 2 DEBUG-MRy(n0c0b0m0,1):{(108, 109) (0x42ec0000, 0x42ed0000), (110, 111) (0x42ee0000, 0x42ef0000), (112, 113) (0x42f00000, 0x42f10000), (114, 115) (0x42f20000, 0x42f30000)]} #d getbf $ly0n0c0b0m0 2

今回は長さ 16 の内積を計算したいので、(100, 101, ..., 115) の組、(116, 117, ..., 131) の組で計算したいです。

ここで、行列レジスタの 2 行目を見ると、(116, 117, ... となって欲しい組が、(108, 109, ... と、(100, 101, ... の後半の組のものが混じってしまっているので、バグに気づけます。

行列レジスタに入れない方の行列(今回で言う \(A\))のデバッグをしたい際は、

gbfn $lm0v4 $lr0v d getbf $lr0n0c0b0m0 1 d getbf $lr2n0c0b0m0 1

出力:

DEBUG-GREG0(n0c0b0m0p0,0):(0, 1) (0x40800000, 0x40900000) #d getbf $lr0n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p1,0):(2, 3) (0x40a00000, 0x40b00000) #d getbf $lr0n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p2,0):(4, 5) (0x40c00000, 0x40d00000) #d getbf $lr0n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p3,0):(6, 7) (0x40e00000, 0x40f00000) #d getbf $lr0n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p0,2):(16, 17) (0x41c00000, 0x41c40000) #d getbf $lr2n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p1,2):(18, 19) (0x41c80000, 0x41cc0000) #d getbf $lr2n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p2,2):(20, 21) (0x41d00000, 0x41d40000) #d getbf $lr2n0c0b0m0 1 DEBUG-GREG0(n0c0b0m0p3,2):(22, 23) (0x41d80000, 0x41dc0000) #d getbf $lr2n0c0b0m0 1

のように 2 行表示すると良いでしょう。

さて、行列 \(A, B\) の右同士の行列積も計算し、合算するコードも書いてみましょう。こちらは、gmmul ではなく、gmfma を使用します。

残りの半分の回答例
gbfn $lm18v4 $nowrite gmwrite $aluf $ly0 gbfn $lm34v4 $nowrite gmwrite $aluf $ly4 gbfn $lm2v4 $nowrite gmfma $ly $aluf $ln0v $ln0v

Inputs

Outputs

Testcases

testcase.vsm

Submission

ログイン / 新規登録