[llvm] ebbcbb2 - [Matrix] Remove redundant transpose with dot product lowering.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun May 14 14:07:56 PDT 2023


Author: Florian Hahn
Date: 2023-05-14T22:07:38+01:00
New Revision: ebbcbb2af51a702fe804cc8c08f176e9989e9fe3

URL: https://github.com/llvm/llvm-project/commit/ebbcbb2af51a702fe804cc8c08f176e9989e9fe3
DIFF: https://github.com/llvm/llvm-project/commit/ebbcbb2af51a702fe804cc8c08f176e9989e9fe3.diff

LOG: [Matrix] Remove redundant transpose with dot product lowering.

Extend dot-product handling to skip transposes of the first operand. As
this is a vector, the conversion between column and row vector via the
transpose isn't needed.

Reviewed By: thegameg

Differential Revision: https://reviews.llvm.org/D148428

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 594556a0b13df..8508c90cc939d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1359,10 +1359,12 @@ class LowerMatrixIntrinsics {
       return;
 
     auto CanBeFlattened = [](Value *Op) {
-      return match(Op, m_OneUse(m_CombineOr(
-                           m_Load(m_Value()),
-                           m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                               m_Value(), m_SpecificInt(1)))));
+      return match(
+          Op, m_OneUse(m_CombineOr(
+                  m_Load(m_Value()),
+                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
+                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                                  m_Value(), m_SpecificInt(1))))));
     };
     // Returns the cost benefit of using \p Op with the dot product lowering. If
     // the returned cost is < 0, the argument is cheaper to use in the
@@ -1374,21 +1376,34 @@ class LowerMatrixIntrinsics {
       FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
       Type *EltTy = VecTy->getElementType();
 
-      if (CanBeFlattened(Op)) {
-        if (N == 1)
-          return InstructionCost(0);
+      if (!CanBeFlattened(Op)) {
+        InstructionCost EmbedCost(0);
+        // Roughly estimate the cost for embedding the columns into a vector.
+        for (unsigned I = 1; I < N; ++I)
+          EmbedCost +=
+              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                                 std::nullopt, TTI::TCK_RecipThroughput);
+        return EmbedCost;
+      }
 
-        return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
-               N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
+      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
+        // The transpose can be skipped for the dot product lowering, roughly
+        // estimate the savings as the cost of embedding the columns in a
+        // vector.
+        InstructionCost EmbedCost(0);
+        for (unsigned I = 1; I < N; ++I)
+          EmbedCost -=
+              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                                 std::nullopt, TTI::TCK_RecipThroughput);
+        return EmbedCost;
       }
 
-      InstructionCost EmbedCost(0);
-      // Roughly estimate the cost for embedding the columns into a vector.
-      for (unsigned I = 1; I < N; ++I)
-        EmbedCost +=
-            TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
-                               std::nullopt, TTI::TCK_RecipThroughput);
-      return EmbedCost;
+      // Costs for loads.
+      if (N == 1)
+        return InstructionCost(0);
+
+      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
+             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
     auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
 
@@ -1410,8 +1425,8 @@ class LowerMatrixIntrinsics {
 
     FusedInsts.insert(MatMul);
     IRBuilder<> Builder(MatMul);
-    auto FlattenArg = [&Builder, &FusedInsts,
-                       &CanBeFlattened](Value *Op) -> Value * {
+    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
+                       this](Value *Op) -> Value * {
       // Matmul must be the only user of loads because we don't use LowerLoad
       // for row vectors (LowerLoad results in scalar loads and shufflevectors
       // instead of single vector load).
@@ -1419,15 +1434,21 @@ class LowerMatrixIntrinsics {
         return Op;
 
       FusedInsts.insert(cast<Instruction>(Op));
+
       // If vector uses the builtin load, lower to a LoadInst
-      Value *Ptr;
+      Value *Arg;
       if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                        m_Value(Ptr)))) {
-        auto *NewLoad = Builder.CreateLoad(Op->getType(), Ptr);
+                        m_Value(Arg)))) {
+        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
         Op->replaceAllUsesWith(NewLoad);
         cast<Instruction>(Op)->eraseFromParent();
         return NewLoad;
+      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
+                               m_Value(Arg)))) {
+        ToRemove.push_back(cast<Instruction>(Op));
+        return Arg;
       }
+
       return Op;
     };
     LHS = FlattenArg(LHS);

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
index 9f1578e8dc9d9..d719b8ae01def 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -5,21 +5,9 @@
 define void @transposed_multiply_feeding_dot_product_v4i322(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: @transposed_multiply_feeding_dot_product_v4i322(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 1
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 2
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <1 x i32> poison, i32 [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 3
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <1 x i32> poison, i32 [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP1]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP5]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <4 x i32> [[TMP10]], [[B:%.*]]
-; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP11]])
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = mul <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <1 x i32> poison, i32 [[TMP1]], i64 0
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -61,18 +49,10 @@ define void @transposed_multiply_feeding_dot_produc_v4i32(<4 x i32> %a, <4 x i32
 ; CHECK-NEXT:    [[TMP11:%.*]] = add <2 x i32> [[TMP8]], [[TMP10]]
 ; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP11]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
 ; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP12]], <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x i32> [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <2 x i32> poison, i32 [[TMP14]], i64 0
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x i32> [[TMP13]], i64 0
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <2 x i32> [[TMP15]], i32 [[TMP16]], i64 1
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x i32> [[TMP6]], i64 1
-; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <2 x i32> poison, i32 [[TMP18]], i64 0
-; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <2 x i32> [[TMP13]], i64 1
-; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <2 x i32> [[TMP19]], i32 [[TMP20]], i64 1
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x i32> [[TMP17]], <2 x i32> [[TMP21]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP23:%.*]] = mul <4 x i32> [[TMP22]], [[C:%.*]]
-; CHECK-NEXT:    [[TMP24:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP23]])
-; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[C:%.*]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0
 ; CHECK-NEXT:    ret void
 ;
 entry:


        


More information about the llvm-commits mailing list