\(8 \times 16\) 行列 \(A\) と、\(8 \times 16\) 行列 \(B\) に対して、行列積 \(C = A^T \times B\) (shape:\(16 \times 16\)) を計算してください。\(A,B,C\) のレイアウトは以下のとおりです。
A: (( 8:2), (2:1, 4_PE:1, 2_W:1))
B: (( 8:2), (2:1, 4_PE:1, 2_W:1))
C: ((16:2), (2:1, 4_PE:1, 2_W:1))
\(A, B, C\) の値はこちらです。
A:
[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
[ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
[ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
[ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
[112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127]]
B:
[[200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215],
[216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231],
[232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247],
[248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263],
[264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279],
[280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295],
[296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311],
[312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327]]
C:
import numpy as np
A = np.array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
[ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
[ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
[ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
[112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127]])
B = np.array([[200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215],
[216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231],
[232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247],
[248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263],
[264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279],
[280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295],
[296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311],
[312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327]])
A.T @ B
[[125440,125888,126336,126784,127232,127680,128128,128576,129024,129472,129920,130368,130816,131264,131712,132160],
[127488,127944,128400,128856,129312,129768,130224,130680,131136,131592,132048,132504,132960,133416,133872,134328],
[129536,130000,130464,130928,131392,131856,132320,132784,133248,133712,134176,134640,135104,135568,136032,136496],
[131584,132056,132528,133000,133472,133944,134416,134888,135360,135832,136304,136776,137248,137720,138192,138664],
[133632,134112,134592,135072,135552,136032,136512,136992,137472,137952,138432,138912,139392,139872,140352,140832],
[135680,136168,136656,137144,137632,138120,138608,139096,139584,140072,140560,141048,141536,142024,142512,143000],
[137728,138224,138720,139216,139712,140208,140704,141200,141696,142192,142688,143184,143680,144176,144672,145168],
[139776,140280,140784,141288,141792,142296,142800,143304,143808,144312,144816,145320,145824,146328,146832,147336],
[141824,142336,142848,143360,143872,144384,144896,145408,145920,146432,146944,147456,147968,148480,148992,149504],
[143872,144392,144912,145432,145952,146472,146992,147512,148032,148552,149072,149592,150112,150632,151152,151672],
[145920,146448,146976,147504,148032,148560,149088,149616,150144,150672,151200,151728,152256,152784,153312,153840],
[147968,148504,149040,149576,150112,150648,151184,151720,152256,152792,153328,153864,154400,154936,155472,156008],
[150016,150560,151104,151648,152192,152736,153280,153824,154368,154912,155456,156000,156544,157088,157632,158176],
[152064,152616,153168,153720,154272,154824,155376,155928,156480,157032,157584,158136,158688,159240,159792,160344],
[154112,154672,155232,155792,156352,156912,157472,158032,158592,159152,159712,160272,160832,161392,161952,162512],
[156160,156728,157296,157864,158432,159000,159568,160136,160704,161272,161840,162408,162976,163544,164112,164680]]
今回は、\(8 \times 16\) 行列を、右半分、左半分に分けて、それぞれ \(8 \times 8\) 行列を計算し、結合することで計算することができます。