[llvm] d87d361 - [Matrix] Fix shape for factored transpose

Adam Nemet via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 11:36:51 PDT 2021


Author: Adam Nemet
Date: 2021-07-27T11:36:13-07:00
New Revision: d87d3615f75502b3adf93d05d4a217f6ab947fdd

URL: https://github.com/llvm/llvm-project/commit/d87d3615f75502b3adf93d05d4a217f6ab947fdd
DIFF: https://github.com/llvm/llvm-project/commit/d87d3615f75502b3adf93d05d4a217f6ab947fdd.diff

LOG: [Matrix] Fix shape for factored transpose

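The input to the factored transpose is the product B * A, and the shape of
that input is C x R (not R x C), so the transpose must be created with C rows
and R columns.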

Differential Revision: https://reviews.llvm.org/D106722

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index ab75cd3f566b4..42c183a6408e3 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -774,6 +774,7 @@ class LowerMatrixIntrinsics {
         ++II;
         Value *A, *B, *AT, *BT;
         ConstantInt *R, *K, *C;
+        // A^t * B ^t -> (B * A)^t
         if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
                            m_Value(A), m_Value(B), m_ConstantInt(R),
                            m_ConstantInt(K), m_ConstantInt(C))) &&
@@ -784,8 +785,8 @@ class LowerMatrixIntrinsics {
           Value *M = Builder.CreateMatrixMultiply(
               BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
           setShapeInfo(M, {C, R});
-          Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
-                                                         C->getZExtValue());
+          Instruction *NewInst = Builder.CreateMatrixTranspose(
+              M, C->getZExtValue(), R->getZExtValue());
           ReplaceAllUsesWith(*I, NewInst);
           if (I->use_empty())
             I->eraseFromParent();

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
index 2a7ac8278d5e2..d7c89ae50295a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
@@ -1012,6 +1012,126 @@ define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
   ret <6 x double> %tt
 }
 
+define <12 x double> @factor_transpose(<6 x double> %a, <8 x double> %b) {
+; CHECK-LABEL: @factor_transpose(
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <6 x double> [[A:%.*]], <6 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x double> poison, double [[TMP1]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul <2 x double> [[BLOCK]], [[SPLAT_SPLAT]]
+; CHECK-NEXT:    [[BLOCK5:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x double> poison, double [[TMP3]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT6]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <2 x double> [[BLOCK5]], [[SPLAT_SPLAT7]]
+; CHECK-NEXT:    [[TMP5:%.*]] = fadd <2 x double> [[TMP2]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP6]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK8:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x double> poison, double [[TMP8]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT9]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = fmul <2 x double> [[BLOCK8]], [[SPLAT_SPLAT10]]
+; CHECK-NEXT:    [[BLOCK11:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT12:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT13:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT12]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <2 x double> [[BLOCK11]], [[SPLAT_SPLAT13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = fadd <2 x double> [[TMP9]], [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x double> [[TMP12]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[BLOCK14:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT15:%.*]] = insertelement <2 x double> poison, double [[TMP15]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT16:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT15]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = fmul <2 x double> [[BLOCK14]], [[SPLAT_SPLAT16]]
+; CHECK-NEXT:    [[BLOCK17:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT19:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT18]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP18:%.*]] = fmul <2 x double> [[BLOCK17]], [[SPLAT_SPLAT19]]
+; CHECK-NEXT:    [[TMP19:%.*]] = fadd <2 x double> [[TMP16]], [[TMP18]]
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <2 x double> [[TMP19]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP20]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK20:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT21:%.*]] = insertelement <2 x double> poison, double [[TMP22]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT22:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT21]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP23:%.*]] = fmul <2 x double> [[BLOCK20]], [[SPLAT_SPLAT22]]
+; CHECK-NEXT:    [[BLOCK23:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP24:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT24:%.*]] = insertelement <2 x double> poison, double [[TMP24]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT25:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT24]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP25:%.*]] = fmul <2 x double> [[BLOCK23]], [[SPLAT_SPLAT25]]
+; CHECK-NEXT:    [[TMP26:%.*]] = fadd <2 x double> [[TMP23]], [[TMP25]]
+; CHECK-NEXT:    [[TMP27:%.*]] = shufflevector <2 x double> [[TMP26]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP28:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP27]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[BLOCK26:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP29:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT27:%.*]] = insertelement <2 x double> poison, double [[TMP29]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT28:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT27]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP30:%.*]] = fmul <2 x double> [[BLOCK26]], [[SPLAT_SPLAT28]]
+; CHECK-NEXT:    [[BLOCK29:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP31:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT30:%.*]] = insertelement <2 x double> poison, double [[TMP31]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT31:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT30]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP32:%.*]] = fmul <2 x double> [[BLOCK29]], [[SPLAT_SPLAT31]]
+; CHECK-NEXT:    [[TMP33:%.*]] = fadd <2 x double> [[TMP30]], [[TMP32]]
+; CHECK-NEXT:    [[TMP34:%.*]] = shufflevector <2 x double> [[TMP33]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP35:%.*]] = shufflevector <4 x double> undef, <4 x double> [[TMP34]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK32:%.*]] = shufflevector <4 x double> [[SPLIT]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP36:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT33:%.*]] = insertelement <2 x double> poison, double [[TMP36]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT34:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT33]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP37:%.*]] = fmul <2 x double> [[BLOCK32]], [[SPLAT_SPLAT34]]
+; CHECK-NEXT:    [[BLOCK35:%.*]] = shufflevector <4 x double> [[SPLIT1]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP38:%.*]] = extractelement <2 x double> [[SPLIT4]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT36:%.*]] = insertelement <2 x double> poison, double [[TMP38]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT37:%.*]] = shufflevector <2 x double> [[SPLAT_SPLATINSERT36]], <2 x double> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP39:%.*]] = fmul <2 x double> [[BLOCK35]], [[SPLAT_SPLAT37]]
+; CHECK-NEXT:    [[TMP40:%.*]] = fadd <2 x double> [[TMP37]], [[TMP39]]
+; CHECK-NEXT:    [[TMP41:%.*]] = shufflevector <2 x double> [[TMP40]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP42:%.*]] = shufflevector <4 x double> [[TMP35]], <4 x double> [[TMP41]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP43:%.*]] = extractelement <4 x double> [[TMP14]], i64 0
+; CHECK-NEXT:    [[TMP44:%.*]] = insertelement <3 x double> undef, double [[TMP43]], i64 0
+; CHECK-NEXT:    [[TMP45:%.*]] = extractelement <4 x double> [[TMP28]], i64 0
+; CHECK-NEXT:    [[TMP46:%.*]] = insertelement <3 x double> [[TMP44]], double [[TMP45]], i64 1
+; CHECK-NEXT:    [[TMP47:%.*]] = extractelement <4 x double> [[TMP42]], i64 0
+; CHECK-NEXT:    [[TMP48:%.*]] = insertelement <3 x double> [[TMP46]], double [[TMP47]], i64 2
+; CHECK-NEXT:    [[TMP49:%.*]] = extractelement <4 x double> [[TMP14]], i64 1
+; CHECK-NEXT:    [[TMP50:%.*]] = insertelement <3 x double> undef, double [[TMP49]], i64 0
+; CHECK-NEXT:    [[TMP51:%.*]] = extractelement <4 x double> [[TMP28]], i64 1
+; CHECK-NEXT:    [[TMP52:%.*]] = insertelement <3 x double> [[TMP50]], double [[TMP51]], i64 1
+; CHECK-NEXT:    [[TMP53:%.*]] = extractelement <4 x double> [[TMP42]], i64 1
+; CHECK-NEXT:    [[TMP54:%.*]] = insertelement <3 x double> [[TMP52]], double [[TMP53]], i64 2
+; CHECK-NEXT:    [[TMP55:%.*]] = extractelement <4 x double> [[TMP14]], i64 2
+; CHECK-NEXT:    [[TMP56:%.*]] = insertelement <3 x double> undef, double [[TMP55]], i64 0
+; CHECK-NEXT:    [[TMP57:%.*]] = extractelement <4 x double> [[TMP28]], i64 2
+; CHECK-NEXT:    [[TMP58:%.*]] = insertelement <3 x double> [[TMP56]], double [[TMP57]], i64 1
+; CHECK-NEXT:    [[TMP59:%.*]] = extractelement <4 x double> [[TMP42]], i64 2
+; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <3 x double> [[TMP58]], double [[TMP59]], i64 2
+; CHECK-NEXT:    [[TMP61:%.*]] = extractelement <4 x double> [[TMP14]], i64 3
+; CHECK-NEXT:    [[TMP62:%.*]] = insertelement <3 x double> undef, double [[TMP61]], i64 0
+; CHECK-NEXT:    [[TMP63:%.*]] = extractelement <4 x double> [[TMP28]], i64 3
+; CHECK-NEXT:    [[TMP64:%.*]] = insertelement <3 x double> [[TMP62]], double [[TMP63]], i64 1
+; CHECK-NEXT:    [[TMP65:%.*]] = extractelement <4 x double> [[TMP42]], i64 3
+; CHECK-NEXT:    [[TMP66:%.*]] = insertelement <3 x double> [[TMP64]], double [[TMP65]], i64 2
+; CHECK-NEXT:    [[TMP67:%.*]] = shufflevector <3 x double> [[TMP48]], <3 x double> [[TMP54]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP68:%.*]] = shufflevector <3 x double> [[TMP60]], <3 x double> [[TMP66]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP69:%.*]] = shufflevector <6 x double> [[TMP67]], <6 x double> [[TMP68]], <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+; CHECK-NEXT:    ret <12 x double> [[TMP69]]
+;
+  %at = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %a, i32 2, i32 3)
+  %bt = call <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double> %b, i32 4, i32 2)
+  %m = call <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double> %at, <8 x double> %bt, i32 3, i32 2, i32 4)
+  ret <12 x double> %m
+}
+
 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
 declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
 declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)


        


More information about the llvm-commits mailing list