Mmul TA 16_8_16

Top: -

Par: 64 lines

Problem Statement

Given an \(8 \times 16\) matrix \(A\) and an \(8 \times 16\) matrix \(B\), compute the matrix product \(C = A^T \times B\) (shape: \(16 \times 16\)). The layouts of \(A, B, C\) are as follows.

A: ((8:2), (2:1, 4_PE:1, 2_W:1))
B: ((8:2), (2:1, 4_PE:1, 2_W:1))
C: ((16:2), (2:1, 4_PE:1, 2_W:1))

The values of \(A, B, C\) are given below.

A:

[[  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
 [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
 [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
 [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
 [ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
 [ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
 [ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
 [112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127]]

B:

[[200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215],
 [216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231],
 [232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247],
 [248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263],
 [264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279],
 [280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295],
 [296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311],
 [312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327]]

C:
import numpy as np

# Input A (8 x 16)
A = np.array([[  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
              [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
              [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
              [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
              [ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
              [ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
              [ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
              [112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127]])

# Input B (8 x 16)
B = np.array([[200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215],
              [216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231],
              [232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247],
              [248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263],
              [264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279],
              [280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295],
              [296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311],
              [312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327]])

# Expected output C = A^T x B (16 x 16)
print(A.T @ B)

[[125440,125888,126336,126784,127232,127680,128128,128576,129024,129472,129920,130368,130816,131264,131712,132160],
 [127488,127944,128400,128856,129312,129768,130224,130680,131136,131592,132048,132504,132960,133416,133872,134328],
 [129536,130000,130464,130928,131392,131856,132320,132784,133248,133712,134176,134640,135104,135568,136032,136496],
 [131584,132056,132528,133000,133472,133944,134416,134888,135360,135832,136304,136776,137248,137720,138192,138664],
 [133632,134112,134592,135072,135552,136032,136512,136992,137472,137952,138432,138912,139392,139872,140352,140832],
 [135680,136168,136656,137144,137632,138120,138608,139096,139584,140072,140560,141048,141536,142024,142512,143000],
 [137728,138224,138720,139216,139712,140208,140704,141200,141696,142192,142688,143184,143680,144176,144672,145168],
 [139776,140280,140784,141288,141792,142296,142800,143304,143808,144312,144816,145320,145824,146328,146832,147336],
 [141824,142336,142848,143360,143872,144384,144896,145408,145920,146432,146944,147456,147968,148480,148992,149504],
 [143872,144392,144912,145432,145952,146472,146992,147512,148032,148552,149072,149592,150112,150632,151152,151672],
 [145920,146448,146976,147504,148032,148560,149088,149616,150144,150672,151200,151728,152256,152784,153312,153840],
 [147968,148504,149040,149576,150112,150648,151184,151720,152256,152792,153328,153864,154400,154936,155472,156008],
 [150016,150560,151104,151648,152192,152736,153280,153824,154368,154912,155456,156000,156544,157088,157632,158176],
 [152064,152616,153168,153720,154272,154824,155376,155928,156480,157032,157584,158136,158688,159240,159792,160344],
 [154112,154672,155232,155792,156352,156912,157472,158032,158592,159152,159712,160272,160832,161392,161952,162512],
 [156160,156728,157296,157864,158432,159000,159568,160136,160704,161272,161840,162408,162976,163544,164112,164680]]
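The inputs follow a simple pattern (\(A\) holds 0 to 127 and \(B\) holds 200 to 327, both filled row by row), so the reference result can also be generated without typing out the arrays. The sketch below additionally checks a closed form for every entry, obtained by expanding \(C_{ij} = \sum_{k=0}^{7} (16k + i)(200 + 16k + j)\) with \(\sum k = 28\) and \(\sum k^2 = 140\); this closed form is a derivation for these specific inputs only, not part of the problem.

```python
import numpy as np

# A[k][i] = 16k + i and B[k][j] = 200 + 16k + j (row-major fill of 0..127 and 200..327)
A = np.arange(128).reshape(8, 16)
B = 200 + np.arange(128).reshape(8, 16)

C = A.T @ B  # shape (16, 16)

# Expanding sum_k (16k + i)(200 + 16k + j) over k = 0..7 gives
# C[i][j] = (200 + j)(448 + 8i) + 448i + 35840
i = np.arange(16).reshape(16, 1)
j = np.arange(16).reshape(1, 16)
closed_form = (200 + j) * (448 + 8 * i) + 448 * i + 35840

assert C.shape == (16, 16)
assert np.array_equal(C, closed_form)
print(C[0, 0], C[15, 15])  # 125440 164680
```

The printed corner values match the first and last entries of the table above.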

Explanation

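In the previous problem, Mmul TA 8_8_8, we computed the product of two \(8 \times 8\) matrices.

This time, each \(8 \times 16\) matrix can be split into a left half and a right half; computing the \(8 \times 8\) products of these halves and combining them gives the full result.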

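In other words, letting \(A_0, A_1\) and \(B_0, B_1\) denote the left and right \(8 \times 8\) halves of \(A\) and \(B\) respectively, the idea is to compute the four blocks of

\[ C = A^T \times B = \begin{bmatrix} A_0^T \times B_0 & A_0^T \times B_1 \\ A_1^T \times B_0 & A_1^T \times B_1 \end{bmatrix} \]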
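The block decomposition can be checked numerically in NumPy. This is a sketch of the math only, not of a submission: `mmul_ta_8x8` is a hypothetical helper standing in for the \(8 \times 8\) routine from Mmul TA 8_8_8, assumed here to compute \(X^T \times Y\) for two \(8 \times 8\) inputs.

```python
import numpy as np

def mmul_ta_8x8(x, y):
    """Hypothetical stand-in for the 8x8 routine: returns x^T @ y."""
    assert x.shape == (8, 8) and y.shape == (8, 8)
    return x.T @ y

A = np.arange(128).reshape(8, 16)        # same values as the A above
B = 200 + np.arange(128).reshape(8, 16)  # same values as the B above

# Left and right 8x8 halves of each 8x16 input
A0, A1 = A[:, :8], A[:, 8:]
B0, B1 = B[:, :8], B[:, 8:]

# Four 8x8 products, assembled into the 16x16 result
C = np.block([
    [mmul_ta_8x8(A0, B0), mmul_ta_8x8(A0, B1)],
    [mmul_ta_8x8(A1, B0), mmul_ta_8x8(A1, B1)],
])

assert np.array_equal(C, A.T @ B)
```

The row index of a block comes from which half of \(A\) is used (transposing turns columns of \(A\) into rows of \(C\)), and the column index comes from which half of \(B\) is used.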

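Transposing the inputs before each of the four \(8 \times 8\) matrix products is enough to get Accepted in about 64 lines.

Transposing \(A\) and \(B\) once at the start and then performing the matrix products gets Accepted in about 36 lines.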

Review of the submatrix layout

The meaning of the ((*:2), (2:1, 4_PE:1, 2_W:1)) layout was explained in detail in Mmul 4_16_16.
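As a hedged illustration only, here is one possible reading of a layout such as ((8:2), (2:1, 4_PE:1, 2_W:1)), assuming CuTe-style shape:stride notation with the leftmost sub-mode varying fastest: row \(r\) contributes \(2r\) to a register offset, and column \(c\) decomposes as \(c = c_0 + 2p + 8w\) with \(c_0 < 2\) (register offset, stride 1), \(p < 4\) (PE index) and \(w < 2\) (W index). This interpretation and the helper name `locate` are assumptions, not definitions from the problem; the authoritative explanation is the one given in Mmul 4_16_16.

```python
def locate(r, c, rows):
    """Hypothetical decoding of ((rows:2), (2:1, 4_PE:1, 2_W:1)).

    Assumption: column c splits colexicographically as c = c0 + 2*pe + 8*w
    (c0 in 0..1, pe in 0..3, w in 0..1), and row r has register stride 2.
    Returns (w, pe, register_offset).
    """
    assert 0 <= r < rows and 0 <= c < 16
    c0 = c % 2          # position inside the (2:1) sub-mode
    pe = (c // 2) % 4   # which of the 4 PEs
    w = c // 8          # which of the 2 Ws
    return w, pe, 2 * r + c0

# For an 8-row matrix (A or B): every element lands on a distinct
# (w, pe, register) triple, with register offsets 0..15.
locs = {locate(r, c, 8) for r in range(8) for c in range(16)}
assert len(locs) == 8 * 16
assert max(reg for _, _, reg in locs) == 15

print(locate(0, 0, 8), locate(7, 15, 8))  # (0, 0, 0) (1, 3, 15)
```

Under the same reading with `rows=16`, the output \(C\) would use register offsets 0 to 31 per PE; for example, `locate(15, 15, 16)` returns `(1, 3, 31)`.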

In this case,

Also, regarding the output \(C\),
