Mmul TA 16_32_8

top Top: -

par Par: 96 lines

Problem Statement

\(32 \times 16\) 行列 \(A\) と、\(32 \times 8\) 行列 \(B\) に対して、行列積 \(C = A^T \times B\) (shape:\(16 \times 8\)) を計算してください。\(A,B,C\) のレイアウトは以下のとおりです。

- A: ((32:2), (2:1, 4_PE:1, 2_W:1))
- B: ((32:1), (4_PE:1, 2_W:1))
- C: ((16:1), (4_PE:1, 2_W:1))

\(A, B, C\) の値はこちらです。

A:

[[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], [ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79], [ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], [ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111], [112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127], [128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143], [144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159], [160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175], [176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191], [192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207], [208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223], [224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239], [240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255], [256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271], [272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287], [288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303], [304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319], [320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335], [336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351], [352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367], [368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383], [384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399], [400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415], [416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431], [432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447], [448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463], 
[464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479], [480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495], [496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511]]

B:

[[-256,-255,-254,-253,-252,-251,-250,-249], [-248,-247,-246,-245,-244,-243,-242,-241], [-240,-239,-238,-237,-236,-235,-234,-233], [-232,-231,-230,-229,-228,-227,-226,-225], [-224,-223,-222,-221,-220,-219,-218,-217], [-216,-215,-214,-213,-212,-211,-210,-209], [-208,-207,-206,-205,-204,-203,-202,-201], [-200,-199,-198,-197,-196,-195,-194,-193], [-192,-191,-190,-189,-188,-187,-186,-185], [-184,-183,-182,-181,-180,-179,-178,-177], [-176,-175,-174,-173,-172,-171,-170,-169], [-168,-167,-166,-165,-164,-163,-162,-161], [-160,-159,-158,-157,-156,-155,-154,-153], [-152,-151,-150,-149,-148,-147,-146,-145], [-144,-143,-142,-141,-140,-139,-138,-137], [-136,-135,-134,-133,-132,-131,-130,-129], [-128,-127,-126,-125,-124,-123,-122,-121], [-120,-119,-118,-117,-116,-115,-114,-113], [-112,-111,-110,-109,-108,-107,-106,-105], [-104,-103,-102,-101,-100,-99,-98,-97], [-96,-95,-94,-93,-92,-91,-90,-89], [-88,-87,-86,-85,-84,-83,-82,-81], [-80,-79,-78,-77,-76,-75,-74,-73], [-72,-71,-70,-69,-68,-67,-66,-65], [-64,-63,-62,-61,-60,-59,-58,-57], [-56,-55,-54,-53,-52,-51,-50,-49], [-48,-47,-46,-45,-44,-43,-42,-41], [-40,-39,-38,-37,-36,-35,-34,-33], [-32,-31,-30,-29,-28,-27,-26,-25], [-24,-23,-22,-21,-20,-19,-18,-17], [-16,-15,-14,-13,-12,-11,-10, -9], [ -8, -7, -6, -5, -4, -3, -2, -1]]

C:
import numpy as np A = np.array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], [ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79], [ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], [ 96, 97, 98, 99,100,101,102,103,104,105,106,107,108,109,110,111], [112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127], [128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143], [144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159], [160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175], [176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191], [192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207], [208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223], [224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239], [240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255], [256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271], [272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287], [288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303], [304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319], [320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335], [336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351], [352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367], [368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383], [384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399], [400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415], [416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431], [432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447], [448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463], 
[464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479], [480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495], [496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511]]) B = np.array([[-256,-255,-254,-253,-252,-251,-250,-249], [-248,-247,-246,-245,-244,-243,-242,-241], [-240,-239,-238,-237,-236,-235,-234,-233], [-232,-231,-230,-229,-228,-227,-226,-225], [-224,-223,-222,-221,-220,-219,-218,-217], [-216,-215,-214,-213,-212,-211,-210,-209], [-208,-207,-206,-205,-204,-203,-202,-201], [-200,-199,-198,-197,-196,-195,-194,-193], [-192,-191,-190,-189,-188,-187,-186,-185], [-184,-183,-182,-181,-180,-179,-178,-177], [-176,-175,-174,-173,-172,-171,-170,-169], [-168,-167,-166,-165,-164,-163,-162,-161], [-160,-159,-158,-157,-156,-155,-154,-153], [-152,-151,-150,-149,-148,-147,-146,-145], [-144,-143,-142,-141,-140,-139,-138,-137], [-136,-135,-134,-133,-132,-131,-130,-129], [-128,-127,-126,-125,-124,-123,-122,-121], [-120,-119,-118,-117,-116,-115,-114,-113], [-112,-111,-110,-109,-108,-107,-106,-105], [-104,-103,-102,-101,-100,-99,-98,-97], [-96,-95,-94,-93,-92,-91,-90,-89], [-88,-87,-86,-85,-84,-83,-82,-81], [-80,-79,-78,-77,-76,-75,-74,-73], [-72,-71,-70,-69,-68,-67,-66,-65], [-64,-63,-62,-61,-60,-59,-58,-57], [-56,-55,-54,-53,-52,-51,-50,-49], [-48,-47,-46,-45,-44,-43,-42,-41], [-40,-39,-38,-37,-36,-35,-34,-33], [-32,-31,-30,-29,-28,-27,-26,-25], [-24,-23,-22,-21,-20,-19,-18,-17], [-16,-15,-14,-13,-12,-11,-10, -9], [ -8, -7, -6, -5, -4, -3, -2, -1]]) A.T @ B

[[-698368,-690432,-682496,-674560,-666624,-658688,-650752,-642816], [-702592,-694624,-686656,-678688,-670720,-662752,-654784,-646816], [-706816,-698816,-690816,-682816,-674816,-666816,-658816,-650816], [-711040,-703008,-694976,-686944,-678912,-670880,-662848,-654816], [-715264,-707200,-699136,-691072,-683008,-674944,-666880,-658816], [-719488,-711392,-703296,-695200,-687104,-679008,-670912,-662816], [-723712,-715584,-707456,-699328,-691200,-683072,-674944,-666816], [-727936,-719776,-711616,-703456,-695296,-687136,-678976,-670816], [-732160,-723968,-715776,-707584,-699392,-691200,-683008,-674816], [-736384,-728160,-719936,-711712,-703488,-695264,-687040,-678816], [-740608,-732352,-724096,-715840,-707584,-699328,-691072,-682816], [-744832,-736544,-728256,-719968,-711680,-703392,-695104,-686816], [-749056,-740736,-732416,-724096,-715776,-707456,-699136,-690816], [-753280,-744928,-736576,-728224,-719872,-711520,-703168,-694816], [-757504,-749120,-740736,-732352,-723968,-715584,-707200,-698816], [-761728,-753312,-744896,-736480,-728064,-719648,-711232,-702816]]

Explanation

方針

行列 \(A, B, C\) を、\(8\times8\) の単位の小行列に分けて考えると、以下のようになります。

A = [ A00, A01 ]
    [ A10, A11 ]
    [ A20, A21 ]
    [ A30, A31 ]

B = [ B0 ]
    [ B1 ]
    [ B2 ]
    [ B3 ]

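このとき、求める行列 \(C = A^T \times B\) は、以下のように \(8\times8\) の小行列 C0, C1 に分割して計算できます。\(A\) を転置して掛けるため、各小行列 Aij も転置して (Aij^T) 用います。

C0: [ A00^T x B0 + A10^T x B1 + A20^T x B2 + A30^T x B3 ]
C1: [ A01^T x B0 + A11^T x B1 + A21^T x B2 + A31^T x B3 ]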

行列 \(A, B, C\) の小行列のアドレス
- A00: $lm0v4, $lm16v4
- A01: $lm2v4, $lm18v4
- A10: $lm32v4, $lm48v4
- A11: $lm34v4, $lm50v4
- A20: $lm64v4, $lm80v4
- A21: $lm66v4, $lm82v4
- A30: $lm96v4, $lm112v4
- A31: $lm98v4, $lm114v4
- B0: $lm128v, $lm136v
- B1: $lm144v, $lm152v
- B2: $lm160v, $lm168v
- B3: $lm176v, $lm184v
- C0: $ln0v, $ln8v
- C1: $ln16v, $ln24v

Inputs

Outputs

Testcases

testcase.vsm

Submission

ログイン / 新規登録