[clang-tools-extra] [llvm] [Matrix] Convert column-vector ops feeding dot product to row-vectors. (PR #72647)

Florian Hahn via cfe-commits cfe-commits at lists.llvm.org
Fri Jan 12 02:00:44 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/72647

From 3dfe86782806f048b130d46afa6293712919f672 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 14 Apr 2023 14:33:57 +0100
Subject: [PATCH 1/2] [Matrix] Convert column-vector ops feeding dot product to
 row-vectors.

Generalize the logic that converts column-vector ops to row-vectors so that
it can convert whole chains of operations, not just a single operand.

A potential next step is to further generalize this to convert
column-vector ops to row-vector ops in general, not just for operands of
dot products. Dot-product handling would then be driven by the general
conversion, rather than the other way around.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 51 ++++++++++++++-----
 .../LowerMatrixIntrinsics/dot-product-int.ll  | 47 ++++-------------
 2 files changed, 47 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 72b9db1e73d73d..c6bb43d3a78cf3 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1332,8 +1332,8 @@ class LowerMatrixIntrinsics {
     if (!IsIntVec && !FMF.allowReassoc())
       return;
 
-    auto CanBeFlattened = [this](Value *Op) {
-      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+    auto CanBeFlattened = [](Value *Op) {
+      if (match(Op, m_BinOp()))
         return true;
       return match(
           Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1346,9 @@ class LowerMatrixIntrinsics {
     // the returned cost is < 0, the argument is cheaper to use in the
     // dot-product lowering.
     auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+      if (ShapeMap.find(Op) == ShapeMap.end())
+        return InstructionCost::getInvalid();
+
       if (!isa<Instruction>(Op))
         return InstructionCost(0);
 
@@ -1356,7 +1359,7 @@ class LowerMatrixIntrinsics {
         InstructionCost EmbedCost(0);
         // Roughly estimate the cost for embedding the columns into a vector.
         for (unsigned I = 1; I < N; ++I)
-          EmbedCost -=
+          EmbedCost +=
               TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                  std::nullopt, TTI::TCK_RecipThroughput);
         return EmbedCost;
@@ -1378,7 +1381,7 @@ class LowerMatrixIntrinsics {
         // vector.
         InstructionCost EmbedCost(0);
         for (unsigned I = 1; I < N; ++I)
-          EmbedCost +=
+          EmbedCost -=
               TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                  std::nullopt, TTI::TCK_RecipThroughput);
         return EmbedCost;
@@ -1391,7 +1394,26 @@ class LowerMatrixIntrinsics {
       return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
              N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
-    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+    SmallPtrSet<Value *, 4> Seen;
+    SmallVector<Value *> WorkList;
+    SmallVector<Value *> ToFlatten;
+    WorkList.push_back(LHS);
+    InstructionCost LHSCost(0);
+    while (!WorkList.empty()) {
+      Value *Op = WorkList.pop_back_val();
+      if (!Seen.insert(Op).second)
+        continue;
+
+      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
+      if (OpCost + LHSCost >= LHSCost)
+        continue;
+
+      LHSCost += OpCost;
+      ToFlatten.push_back(Op);
+      if (auto *I = dyn_cast<Instruction>(Op))
+        WorkList.append(I->op_begin(), I->op_end());
+    }
 
     // We compare the costs of a vector.reduce.add to sequential add.
     int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
@@ -1412,16 +1434,16 @@ class LowerMatrixIntrinsics {
     FusedInsts.insert(MatMul);
     IRBuilder<> Builder(MatMul);
     auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
-                       this](Value *Op) -> Value * {
+                       this](Value *Op) {
       // Matmul must be the only user of loads because we don't use LowerLoad
       // for row vectors (LowerLoad results in scalar loads and shufflevectors
       // instead of single vector load).
       if (!CanBeFlattened(Op))
-        return Op;
+        return;
 
       if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
         ShapeMap[Op] = ShapeMap[Op].t();
-        return Op;
+        return;
       }
 
       FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1454,19 @@ class LowerMatrixIntrinsics {
         auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
         Op->replaceAllUsesWith(NewLoad);
         cast<Instruction>(Op)->eraseFromParent();
-        return NewLoad;
+        return;
       } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                                m_Value(Arg)))) {
         ToRemove.push_back(cast<Instruction>(Op));
-        return Arg;
+        Op->replaceAllUsesWith(Arg);
+        return;
       }
-
-      return Op;
     };
-    LHS = FlattenArg(LHS);
+
+    for (auto *V : ToFlatten)
+      FlattenArg(V);
+
+    LHS = MatMul->getArgOperand(0);
 
     // Insert mul/fmul and llvm.vector.reduce.fadd
     Value *Mul =
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
index 7bbd0c50048551..f15dbed1f1f513 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
@@ -119,44 +119,15 @@ entry:
 define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
 ; CHECK-LABEL: @add_chain_feeding_dotproduct_i32_v8_1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT:    [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT:    [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT:    [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT:    [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT:    [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT:    [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]]
-; CHECK-NEXT:    [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]]
-; CHECK-NEXT:    [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]]
-; CHECK-NEXT:    [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]]
-; CHECK-NEXT:    [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]]
-; CHECK-NEXT:    [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]]
-; CHECK-NEXT:    [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]]
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]]
-; CHECK-NEXT:    [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]]
-; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]])
-; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0
-; CHECK-NEXT:    ret <1 x i32> [[TMP18]]
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]]
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <8 x i32> [[TMP0]], [[SPLIT2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[D:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i64 0
+; CHECK-NEXT:    ret <1 x i32> [[TMP4]]
 ;
 entry:
   %add.1 = add <8 x i32> %a, %b

From 5c1a3f2cb7ed28c21e8a1e976754e69e9e3bb452 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 12 Jan 2024 09:55:14 +0000
Subject: [PATCH 2/2] !fixup add comment

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index c6bb43d3a78cf3..b528762b545659 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1395,6 +1395,9 @@ class LowerMatrixIntrinsics {
              N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
 
+    // Iterate over LHS and operations feeding LHS and check if it is profitable
+    // to flatten the visited ops.  For each op, we compute the difference
+    // between the flattened and matrix versions.
     SmallPtrSet<Value *, 4> Seen;
     SmallVector<Value *> WorkList;
     SmallVector<Value *> ToFlatten;



More information about the cfe-commits mailing list