\(32 \times 16\) 行列 \(A\) と、\(32 \times 8\) 行列 \(B\) に対して、行列積 \(C = A^T \times B\) (shape:\(16 \times 8\)) を計算してください。\(A,B,C\) のレイアウトは以下のとおりです。
A: ((32:2), (2:1, 4_PE:1, 2_W:1))
B: ((32:1), (4_PE:1, 2_W:1))
C: ((16:1), (4_PE:1, 2_W:1))
\(A, B, C\) の値はこちらです。
A:
[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
[ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
[ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
[ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
[112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127],
[128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143],
[144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159],
[160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175],
[176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191],
[192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207],
[208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223],
[224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239],
[240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255],
[256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271],
[272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287],
[288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303],
[304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319],
[320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335],
[336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351],
[352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367],
[368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383],
[384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399],
[400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415],
[416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431],
[432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447],
[448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463],
[464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479],
[480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495],
[496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511]]
B:
[[-256,-255,-254,-253,-252,-251,-250,-249],
[-248,-247,-246,-245,-244,-243,-242,-241],
[-240,-239,-238,-237,-236,-235,-234,-233],
[-232,-231,-230,-229,-228,-227,-226,-225],
[-224,-223,-222,-221,-220,-219,-218,-217],
[-216,-215,-214,-213,-212,-211,-210,-209],
[-208,-207,-206,-205,-204,-203,-202,-201],
[-200,-199,-198,-197,-196,-195,-194,-193],
[-192,-191,-190,-189,-188,-187,-186,-185],
[-184,-183,-182,-181,-180,-179,-178,-177],
[-176,-175,-174,-173,-172,-171,-170,-169],
[-168,-167,-166,-165,-164,-163,-162,-161],
[-160,-159,-158,-157,-156,-155,-154,-153],
[-152,-151,-150,-149,-148,-147,-146,-145],
[-144,-143,-142,-141,-140,-139,-138,-137],
[-136,-135,-134,-133,-132,-131,-130,-129],
[-128,-127,-126,-125,-124,-123,-122,-121],
[-120,-119,-118,-117,-116,-115,-114,-113],
[-112,-111,-110,-109,-108,-107,-106,-105],
[-104,-103,-102,-101,-100,-99,-98,-97],
[-96,-95,-94,-93,-92,-91,-90,-89],
[-88,-87,-86,-85,-84,-83,-82,-81],
[-80,-79,-78,-77,-76,-75,-74,-73],
[-72,-71,-70,-69,-68,-67,-66,-65],
[-64,-63,-62,-61,-60,-59,-58,-57],
[-56,-55,-54,-53,-52,-51,-50,-49],
[-48,-47,-46,-45,-44,-43,-42,-41],
[-40,-39,-38,-37,-36,-35,-34,-33],
[-32,-31,-30,-29,-28,-27,-26,-25],
[-24,-23,-22,-21,-20,-19,-18,-17],
[-16,-15,-14,-13,-12,-11,-10, -9],
[ -8, -7, -6, -5, -4, -3, -2, -1]]
C:
import numpy as np
A = np.array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
[ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
[ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
[ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111],
[112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127],
[128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143],
[144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159],
[160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175],
[176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191],
[192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207],
[208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223],
[224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239],
[240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255],
[256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271],
[272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287],
[288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303],
[304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319],
[320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335],
[336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351],
[352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367],
[368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383],
[384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399],
[400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415],
[416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431],
[432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447],
[448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463],
[464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479],
[480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495],
[496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511]])
B = np.array([[-256,-255,-254,-253,-252,-251,-250,-249],
[-248,-247,-246,-245,-244,-243,-242,-241],
[-240,-239,-238,-237,-236,-235,-234,-233],
[-232,-231,-230,-229,-228,-227,-226,-225],
[-224,-223,-222,-221,-220,-219,-218,-217],
[-216,-215,-214,-213,-212,-211,-210,-209],
[-208,-207,-206,-205,-204,-203,-202,-201],
[-200,-199,-198,-197,-196,-195,-194,-193],
[-192,-191,-190,-189,-188,-187,-186,-185],
[-184,-183,-182,-181,-180,-179,-178,-177],
[-176,-175,-174,-173,-172,-171,-170,-169],
[-168,-167,-166,-165,-164,-163,-162,-161],
[-160,-159,-158,-157,-156,-155,-154,-153],
[-152,-151,-150,-149,-148,-147,-146,-145],
[-144,-143,-142,-141,-140,-139,-138,-137],
[-136,-135,-134,-133,-132,-131,-130,-129],
[-128,-127,-126,-125,-124,-123,-122,-121],
[-120,-119,-118,-117,-116,-115,-114,-113],
[-112,-111,-110,-109,-108,-107,-106,-105],
[-104,-103,-102,-101,-100,-99,-98,-97],
[-96,-95,-94,-93,-92,-91,-90,-89],
[-88,-87,-86,-85,-84,-83,-82,-81],
[-80,-79,-78,-77,-76,-75,-74,-73],
[-72,-71,-70,-69,-68,-67,-66,-65],
[-64,-63,-62,-61,-60,-59,-58,-57],
[-56,-55,-54,-53,-52,-51,-50,-49],
[-48,-47,-46,-45,-44,-43,-42,-41],
[-40,-39,-38,-37,-36,-35,-34,-33],
[-32,-31,-30,-29,-28,-27,-26,-25],
[-24,-23,-22,-21,-20,-19,-18,-17],
[-16,-15,-14,-13,-12,-11,-10, -9],
[ -8, -7, -6, -5, -4, -3, -2, -1]])
A.T @ B
[[-698368,-690432,-682496,-674560,-666624,-658688,-650752,-642816],
[-702592,-694624,-686656,-678688,-670720,-662752,-654784,-646816],
[-706816,-698816,-690816,-682816,-674816,-666816,-658816,-650816],
[-711040,-703008,-694976,-686944,-678912,-670880,-662848,-654816],
[-715264,-707200,-699136,-691072,-683008,-674944,-666880,-658816],
[-719488,-711392,-703296,-695200,-687104,-679008,-670912,-662816],
[-723712,-715584,-707456,-699328,-691200,-683072,-674944,-666816],
[-727936,-719776,-711616,-703456,-695296,-687136,-678976,-670816],
[-732160,-723968,-715776,-707584,-699392,-691200,-683008,-674816],
[-736384,-728160,-719936,-711712,-703488,-695264,-687040,-678816],
[-740608,-732352,-724096,-715840,-707584,-699328,-691072,-682816],
[-744832,-736544,-728256,-719968,-711680,-703392,-695104,-686816],
[-749056,-740736,-732416,-724096,-715776,-707456,-699136,-690816],
[-753280,-744928,-736576,-728224,-719872,-711520,-703168,-694816],
[-757504,-749120,-740736,-732352,-723968,-715584,-707200,-698816],
[-761728,-753312,-744896,-736480,-728064,-719648,-711232,-702816]]