[llvm] ebbcbb2 - [Matrix] Remove redundant transpose with dot product lowering.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sun May 14 14:07:56 PDT 2023
Author: Florian Hahn
Date: 2023-05-14T22:07:38+01:00
New Revision: ebbcbb2af51a702fe804cc8c08f176e9989e9fe3
URL: https://github.com/llvm/llvm-project/commit/ebbcbb2af51a702fe804cc8c08f176e9989e9fe3
DIFF: https://github.com/llvm/llvm-project/commit/ebbcbb2af51a702fe804cc8c08f176e9989e9fe3.diff
LOG: [Matrix] Remove redundant transpose with dot product lowering.
Extend the dot-product handling to also skip transposes of the first
operand. Since that operand is a row vector, the conversion between
column and row vector that the transpose performs is not needed.
Reviewed By: thegameg
Differential Revision: https://reviews.llvm.org/D148428
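For illustration, here is a minimal sketch of the kind of IR this targets. The
example is hypothetical (it is not taken from the patch) and assumes the
documented llvm.matrix.* intrinsic signatures: a 4 x 1 column vector %a is
transposed into a 1 x 4 row vector, which is then multiplied with the 4 x 1
column vector %b, i.e. a dot product.

  ; Hypothetical example: dot product whose first operand is a transpose.
  define i32 @dot_of_transposed_column(<4 x i32> %a, <4 x i32> %b) {
  entry:
    ; Treat %a as a 4 x 1 column vector and transpose it into a 1 x 4 row vector.
    %a.t = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 4, i32 1)
    ; (1 x 4) * (4 x 1) -> 1 x 1, i.e. the dot product of %a.t and %b.
    %dot = call <1 x i32> @llvm.matrix.multiply.v1i32.v4i32.v4i32(<4 x i32> %a.t, <4 x i32> %b, i32 1, i32 4, i32 1)
    %res = extractelement <1 x i32> %dot, i64 0
    ret i32 %res
  }

  declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)
  declare <1 x i32> @llvm.matrix.multiply.v1i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)

Because the dot-product lowering consumes the first operand as a flat vector
anyway, the transpose can be dropped and %a used directly; running the
LowerMatrixIntrinsics pass (e.g. opt -passes=lower-matrix-intrinsics) then
leaves just an element-wise mul followed by llvm.vector.reduce.add, which is
what the updated CHECK lines in the test diff below verify.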
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 594556a0b13df..8508c90cc939d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1359,10 +1359,12 @@ class LowerMatrixIntrinsics {
       return;
     auto CanBeFlattened = [](Value *Op) {
-      return match(Op, m_OneUse(m_CombineOr(
-                           m_Load(m_Value()),
-                           m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                               m_Value(), m_SpecificInt(1)))));
+      return match(
+          Op, m_OneUse(m_CombineOr(
+                  m_Load(m_Value()),
+                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
+                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                                  m_Value(), m_SpecificInt(1))))));
     };
     // Returns the cost benefit of using \p Op with the dot product lowering. If
     // the returned cost is < 0, the argument is cheaper to use in the
@@ -1374,21 +1376,34 @@ class LowerMatrixIntrinsics {
       FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
       Type *EltTy = VecTy->getElementType();
-      if (CanBeFlattened(Op)) {
-        if (N == 1)
-          return InstructionCost(0);
+      if (!CanBeFlattened(Op)) {
+        InstructionCost EmbedCost(0);
+        // Roughly estimate the cost for embedding the columns into a vector.
+        for (unsigned I = 1; I < N; ++I)
+          EmbedCost -=
+              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                                 std::nullopt, TTI::TCK_RecipThroughput);
+        return EmbedCost;
+      }
-        return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
-               N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
+      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
+        // The transpose can be skipped for the dot product lowering, roughly
+        // estimate the savings as the cost of embedding the columns in a
+        // vector.
+        InstructionCost EmbedCost(0);
+        for (unsigned I = 1; I < N; ++I)
+          EmbedCost +=
+              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                                 std::nullopt, TTI::TCK_RecipThroughput);
+        return EmbedCost;
       }
-      InstructionCost EmbedCost(0);
-      // Roughly estimate the cost for embedding the columns into a vector.
-      for (unsigned I = 1; I < N; ++I)
-        EmbedCost +=
-            TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
-                               std::nullopt, TTI::TCK_RecipThroughput);
-      return EmbedCost;
+      // Costs for loads.
+      if (N == 1)
+        return InstructionCost(0);
+
+      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
+             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
     auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
@@ -1410,8 +1425,8 @@ class LowerMatrixIntrinsics {
     FusedInsts.insert(MatMul);
     IRBuilder<> Builder(MatMul);
-    auto FlattenArg = [&Builder, &FusedInsts,
-                       &CanBeFlattened](Value *Op) -> Value * {
+    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
+                       this](Value *Op) -> Value * {
       // Matmul must be the only user of loads because we don't use LowerLoad
       // for row vectors (LowerLoad results in scalar loads and shufflevectors
       // instead of single vector load).
@@ -1419,15 +1434,21 @@ class LowerMatrixIntrinsics {
         return Op;
       FusedInsts.insert(cast<Instruction>(Op));
+
       // If vector uses the builtin load, lower to a LoadInst
-      Value *Ptr;
+      Value *Arg;
       if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                        m_Value(Ptr)))) {
-        auto *NewLoad = Builder.CreateLoad(Op->getType(), Ptr);
+                        m_Value(Arg)))) {
+        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
         Op->replaceAllUsesWith(NewLoad);
         cast<Instruction>(Op)->eraseFromParent();
         return NewLoad;
+      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
+                               m_Value(Arg)))) {
+        ToRemove.push_back(cast<Instruction>(Op));
+        return Arg;
       }
+
       return Op;
     };
     LHS = FlattenArg(LHS);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
index 9f1578e8dc9d9..d719b8ae01def 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -5,21 +5,9 @@
define void @transposed_multiply_feeding_dot_product_v4i322(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: @transposed_multiply_feeding_dot_product_v4i322(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 0
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[TMP0]], i64 0
-; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 1
-; CHECK-NEXT: [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 2
-; CHECK-NEXT: [[TMP5:%.*]] = insertelement <1 x i32> poison, i32 [[TMP4]], i64 0
-; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 3
-; CHECK-NEXT: [[TMP7:%.*]] = insertelement <1 x i32> poison, i32 [[TMP6]], i64 0
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP1]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP5]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP11:%.*]] = mul <4 x i32> [[TMP10]], [[B:%.*]]
-; CHECK-NEXT: [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP11]])
-; CHECK-NEXT: [[TMP13:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i64 0
+; CHECK-NEXT: [[TMP0:%.*]] = mul <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <1 x i32> poison, i32 [[TMP1]], i64 0
; CHECK-NEXT: ret void
;
entry:
@@ -61,18 +49,10 @@ define void @transposed_multiply_feeding_dot_produc_v4i32(<4 x i32> %a, <4 x i32
; CHECK-NEXT: [[TMP11:%.*]] = add <2 x i32> [[TMP8]], [[TMP10]]
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP11]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP12]], <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[TMP6]], i64 0
-; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x i32> poison, i32 [[TMP14]], i64 0
-; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[TMP13]], i64 0
-; CHECK-NEXT: [[TMP17:%.*]] = insertelement <2 x i32> [[TMP15]], i32 [[TMP16]], i64 1
-; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x i32> [[TMP6]], i64 1
-; CHECK-NEXT: [[TMP19:%.*]] = insertelement <2 x i32> poison, i32 [[TMP18]], i64 0
-; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x i32> [[TMP13]], i64 1
-; CHECK-NEXT: [[TMP21:%.*]] = insertelement <2 x i32> [[TMP19]], i32 [[TMP20]], i64 1
-; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <2 x i32> [[TMP17]], <2 x i32> [[TMP21]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP23:%.*]] = mul <4 x i32> [[TMP22]], [[C:%.*]]
-; CHECK-NEXT: [[TMP24:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP23]])
-; CHECK-NEXT: [[TMP25:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i64 0
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[C:%.*]]
+; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT: [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0
; CHECK-NEXT: ret void
;
entry: