Mmul TB 32_8_16

Top: -

Par: 96 lines

Problem Statement

For a \(32 \times 8\) matrix \(A\) and a \(16 \times 8\) matrix \(B\), compute the matrix product \(C = A \times B^T\) (shape: \(32 \times 16\)). The layouts of \(A\), \(B\), and \(C\) are as follows.

A: ((32:1), (4_PE:1, 2_W:1))
B: ((16:1), (4_PE:1, 2_W:1))
C: ((32:2), (2:1, 4_PE:1, 2_W:1))

The values of \(A\), \(B\), and \(C\) are as follows.

A:

[[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23], [ 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39], [ 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55], [ 56, 57, 58, 59, 60, 61, 62, 63], [ 64, 65, 66, 67, 68, 69, 70, 71], [ 72, 73, 74, 75, 76, 77, 78, 79], [ 80, 81, 82, 83, 84, 85, 86, 87], [ 88, 89, 90, 91, 92, 93, 94, 95], [ 96, 97, 98, 99,100,101,102,103], [104,105,106,107,108,109,110,111], [112,113,114,115,116,117,118,119], [120,121,122,123,124,125,126,127], [128,129,130,131,132,133,134,135], [136,137,138,139,140,141,142,143], [144,145,146,147,148,149,150,151], [152,153,154,155,156,157,158,159], [160,161,162,163,164,165,166,167], [168,169,170,171,172,173,174,175], [176,177,178,179,180,181,182,183], [184,185,186,187,188,189,190,191], [192,193,194,195,196,197,198,199], [200,201,202,203,204,205,206,207], [208,209,210,211,212,213,214,215], [216,217,218,219,220,221,222,223], [224,225,226,227,228,229,230,231], [232,233,234,235,236,237,238,239], [240,241,242,243,244,245,246,247], [248,249,250,251,252,253,254,255]]

B:

[[300,301,302,303,304,305,306,307], [308,309,310,311,312,313,314,315], [316,317,318,319,320,321,322,323], [324,325,326,327,328,329,330,331], [332,333,334,335,336,337,338,339], [340,341,342,343,344,345,346,347], [348,349,350,351,352,353,354,355], [356,357,358,359,360,361,362,363], [364,365,366,367,368,369,370,371], [372,373,374,375,376,377,378,379], [380,381,382,383,384,385,386,387], [388,389,390,391,392,393,394,395], [396,397,398,399,400,401,402,403], [404,405,406,407,408,409,410,411], [412,413,414,415,416,417,418,419], [420,421,422,423,424,425,426,427]]

C:
import numpy as np
# A: 32 x 8
A = np.array([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23], [ 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39], [ 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55], [ 56, 57, 58, 59, 60, 61, 62, 63], [ 64, 65, 66, 67, 68, 69, 70, 71], [ 72, 73, 74, 75, 76, 77, 78, 79], [ 80, 81, 82, 83, 84, 85, 86, 87], [ 88, 89, 90, 91, 92, 93, 94, 95], [ 96, 97, 98, 99,100,101,102,103], [104,105,106,107,108,109,110,111], [112,113,114,115,116,117,118,119], [120,121,122,123,124,125,126,127], [128,129,130,131,132,133,134,135], [136,137,138,139,140,141,142,143], [144,145,146,147,148,149,150,151], [152,153,154,155,156,157,158,159], [160,161,162,163,164,165,166,167], [168,169,170,171,172,173,174,175], [176,177,178,179,180,181,182,183], [184,185,186,187,188,189,190,191], [192,193,194,195,196,197,198,199], [200,201,202,203,204,205,206,207], [208,209,210,211,212,213,214,215], [216,217,218,219,220,221,222,223], [224,225,226,227,228,229,230,231], [232,233,234,235,236,237,238,239], [240,241,242,243,244,245,246,247], [248,249,250,251,252,253,254,255]])
# B: 16 x 8
B = np.array([[300,301,302,303,304,305,306,307], [308,309,310,311,312,313,314,315], [316,317,318,319,320,321,322,323], [324,325,326,327,328,329,330,331], [332,333,334,335,336,337,338,339], [340,341,342,343,344,345,346,347], [348,349,350,351,352,353,354,355], [356,357,358,359,360,361,362,363], [364,365,366,367,368,369,370,371], [372,373,374,375,376,377,378,379], [380,381,382,383,384,385,386,387], [388,389,390,391,392,393,394,395], [396,397,398,399,400,401,402,403], [404,405,406,407,408,409,410,411], [412,413,414,415,416,417,418,419], [420,421,422,423,424,425,426,427]])
# C = A x B^T, shape 32 x 16
print(A @ B.T)

[[ 8540, 8764, 8988, 9212, 9436, 9660, 9884, 10108, 10332, 10556, 10780, 11004, 11228, 11452, 11676, 11900], [ 27964, 28700, 29436, 30172, 30908, 31644, 32380, 33116, 33852, 34588, 35324, 36060, 36796, 37532, 38268, 39004], [ 47388, 48636, 49884, 51132, 52380, 53628, 54876, 56124, 57372, 58620, 59868, 61116, 62364, 63612, 64860, 66108], [ 66812, 68572, 70332, 72092, 73852, 75612, 77372, 79132, 80892, 82652, 84412, 86172, 87932, 89692, 91452, 93212], [ 86236, 88508, 90780, 93052, 95324, 97596, 99868,102140,104412,106684,108956,111228,113500,115772,118044,120316], [105660,108444,111228,114012,116796,119580,122364,125148,127932,130716,133500,136284,139068,141852,144636,147420], [125084,128380,131676,134972,138268,141564,144860,148156,151452,154748,158044,161340,164636,167932,171228,174524], [144508,148316,152124,155932,159740,163548,167356,171164,174972,178780,182588,186396,190204,194012,197820,201628], [163932,168252,172572,176892,181212,185532,189852,194172,198492,202812,207132,211452,215772,220092,224412,228732], [183356,188188,193020,197852,202684,207516,212348,217180,222012,226844,231676,236508,241340,246172,251004,255836], [202780,208124,213468,218812,224156,229500,234844,240188,245532,250876,256220,261564,266908,272252,277596,282940], [222204,228060,233916,239772,245628,251484,257340,263196,269052,274908,280764,286620,292476,298332,304188,310044], [241628,247996,254364,260732,267100,273468,279836,286204,292572,298940,305308,311676,318044,324412,330780,337148], [261052,267932,274812,281692,288572,295452,302332,309212,316092,322972,329852,336732,343612,350492,357372,364252], [280476,287868,295260,302652,310044,317436,324828,332220,339612,347004,354396,361788,369180,376572,383964,391356], [299900,307804,315708,323612,331516,339420,347324,355228,363132,371036,378940,386844,394748,402652,410556,418460], [319324,327740,336156,344572,352988,361404,369820,378236,386652,395068,403484,411900,420316,428732,437148,445564], 
[338748,347676,356604,365532,374460,383388,392316,401244,410172,419100,428028,436956,445884,454812,463740,472668], [358172,367612,377052,386492,395932,405372,414812,424252,433692,443132,452572,462012,471452,480892,490332,499772], [377596,387548,397500,407452,417404,427356,437308,447260,457212,467164,477116,487068,497020,506972,516924,526876], [397020,407484,417948,428412,438876,449340,459804,470268,480732,491196,501660,512124,522588,533052,543516,553980], [416444,427420,438396,449372,460348,471324,482300,493276,504252,515228,526204,537180,548156,559132,570108,581084], [435868,447356,458844,470332,481820,493308,504796,516284,527772,539260,550748,562236,573724,585212,596700,608188], [455292,467292,479292,491292,503292,515292,527292,539292,551292,563292,575292,587292,599292,611292,623292,635292], [474716,487228,499740,512252,524764,537276,549788,562300,574812,587324,599836,612348,624860,637372,649884,662396], [494140,507164,520188,533212,546236,559260,572284,585308,598332,611356,624380,637404,650428,663452,676476,689500], [513564,527100,540636,554172,567708,581244,594780,608316,621852,635388,648924,662460,675996,689532,703068,716604], [532988,547036,561084,575132,589180,603228,617276,631324,645372,659420,673468,687516,701564,715612,729660,743708], [552412,566972,581532,596092,610652,625212,639772,654332,668892,683452,698012,712572,727132,741692,756252,770812], [571836,586908,601980,617052,632124,647196,662268,677340,692412,707484,722556,737628,752700,767772,782844,797916], [591260,606844,622428,638012,653596,669180,684764,700348,715932,731516,747100,762684,778268,793852,809436,825020], [610684,626780,642876,658972,675068,691164,707260,723356,739452,755548,771644,787740,803836,819932,836028,852124]]

Explanation

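This problem builds on Mmul TB 4_8_16: the first dimension of \(A\) grows by a factor of \(8\), from \(4\) to \(32\), and the first dimension of the output grows with it. The solution uses gmmul 16 times and gmfma 0 times.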
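One way to read the count of 16 is as a tiling of \(C\): 8 row blocks of 4 rows each, times 2 column halves of 8 columns each. The NumPy sketch below checks only the arithmetic of that tiling. The assumption that one gmmul corresponds to one \(4 \times 8\) tile is inferred from the example answer, not stated by the problem.

```python
import numpy as np

# Same values as in the problem statement: A is 32x8, B is 16x8.
A = np.arange(256).reshape(32, 8)
B = np.arange(300, 428).reshape(16, 8)
C = A @ B.T  # reference result, shape (32, 16)

ROWS_PER_BLOCK = 4  # rows of A per block, the size used in Mmul TB 4_8_16
COLS_PER_TILE = 8   # assumption: one gmmul covers 8 rows of B (8 columns of C)

tiles = 0
C_tiled = np.zeros_like(C)
for r in range(0, 32, ROWS_PER_BLOCK):
    for c in range(0, 16, COLS_PER_TILE):
        # One small (4x8) @ (8x8) product per tile.
        C_tiled[r:r + ROWS_PER_BLOCK, c:c + COLS_PER_TILE] = (
            A[r:r + ROWS_PER_BLOCK] @ B[c:c + COLS_PER_TILE].T
        )
        tiles += 1

print(tiles)                       # number of small products
print(np.array_equal(C, C_tiled))  # whether the tiles reassemble C
```

Each tile holds 32 values, so 16 tiles cover all \(32 \times 16 = 512\) entries of \(C\).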

Account for the fact that input \(A\) has moved from $lm[0:8] to $lm[0:64], and \(B\) from $lm[8:40] to $lm[64:96].
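The new address ranges can be tabulated with a short sketch. It assumes that every group of 4 matrix rows occupies 8 consecutive $lm words, which is how the base problem's \(A\) fits in $lm[0:8]; the names `WORDS_PER_BLOCK`, `a_blocks`, and `b_chunks` are made up for illustration.

```python
WORDS_PER_BLOCK = 8      # assumption: 4 rows occupy 8 words, as A did in $lm[0:8]
A_BASE, A_END = 0, 64    # A now lives in $lm[0:64]
B_BASE, B_END = 64, 96   # B now lives in $lm[64:96]

# (start, end) word ranges, one per group of 4 rows.
a_blocks = [(s, s + WORDS_PER_BLOCK) for s in range(A_BASE, A_END, WORDS_PER_BLOCK)]
b_chunks = [(s, s + WORDS_PER_BLOCK) for s in range(B_BASE, B_END, WORDS_PER_BLOCK)]

for i, (lo, hi) in enumerate(a_blocks):
    print(f"A rows {4 * i:2d}..{4 * i + 3:2d}: $lm[{lo}:{hi}]")
for j, (lo, hi) in enumerate(b_chunks):
    print(f"B rows {4 * j:2d}..{4 * j + 3:2d}: $lm[{lo}:{hi}]")
```

Under this assumption \(A\) splits into 8 blocks and \(B\) into 4 chunks starting at words 64, 72, 80, and 88.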

It may be a good idea to generate the answer with a script.

Example answer (1/8): starting from Mmul TB 4_8_16, changing the code to

gbfn $lm64v $nowrite # changed from $lm8v
gmwrite $aluf $ly0
gbfn $lm72v $nowrite # changed from $lm16v
gmwrite $aluf $ly4
gbfn $lm0v $nowrite
gmmul $ly $aluf $ln0v4
gbfn $lm80v $nowrite # changed from $lm24v
gmwrite $aluf $ly0
gbfn $lm88v $nowrite # changed from $lm32v
gmwrite $aluf $ly4
gbfn $lm0v $nowrite
gmmul $ly $aluf $ln2v4

gives "64 value(s) correct, but 448 value(s) mismatch", so 64/512 (= 1/8) of the values now match.

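At this point the first 1/8 of input \(A\) ($lm[0:8]) has been multiplied against \(B\) in $lm[64:96] and the result written to $ln[0:16], so process the remaining 7/8 in the same way.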
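A generator script for the full answer might look like the sketch below. It repeats the 12-line pattern of the 1/8 example once per block of \(A\). The address progression is an extrapolation, not something the problem states: block `i` of \(A\) is assumed to start at $lm word `8 * i`, and its output at $ln word `16 * i`. The function names `block_lines` and `generate` are hypothetical, and the output has not been checked against the judge.

```python
def block_lines(i: int) -> list[str]:
    """Instructions for rows 4*i .. 4*i+3 of A (block i of 8)."""
    a = f"$lm{8 * i}v"  # assumption: block i of A starts at $lm word 8*i
    out = 16 * i        # assumption: block i of C starts at $ln word 16*i
    return [
        # First half of B ($lm[64:80]) into the matrix register.
        "gbfn $lm64v $nowrite",
        "gmwrite $aluf $ly0",
        "gbfn $lm72v $nowrite",
        "gmwrite $aluf $ly4",
        f"gbfn {a} $nowrite",
        f"gmmul $ly $aluf $ln{out}v4",
        # Second half of B ($lm[80:96]) into the matrix register.
        "gbfn $lm80v $nowrite",
        "gmwrite $aluf $ly0",
        "gbfn $lm88v $nowrite",
        "gmwrite $aluf $ly4",
        f"gbfn {a} $nowrite",
        f"gmmul $ly $aluf $ln{out + 2}v4",
    ]


def generate() -> list[str]:
    """Concatenate the 12-line pattern for all 8 blocks of A."""
    lines: list[str] = []
    for i in range(8):
        lines.extend(block_lines(i))
    return lines


if __name__ == "__main__":
    print("\n".join(generate()))
```

With `i = 0` this reproduces the 1/8 example exactly. The full output is 8 × 12 = 96 lines with 16 gmmul instructions, which agrees with the listed par and with the gmmul count given in the explanation.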

Testcases

testcase.vsm
