[llvm] [Matrix] Use DenseMap for ShapeMap instead of ValueMap. (PR #118282)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 2 14:04:50 PST 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/118282
>From db421102bde1f551fa44aabbee867abf12eaeed7 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 1 Dec 2024 19:01:02 +0000
Subject: [PATCH 1/3] [Matrix] Use DenseMap for ShapeMap instead of ValueMap.
ValueMap automatically updates entries to the new value when the key is
RAUW'd. This can cause instructions that are not expected to have shape
info (e.g. the shufflevector in the added test case) to be added to the
map, which leads to incorrect results. ValueMap was originally used for
the transpose optimizations, but those now all use
updateShapeAndReplaceAllUsesWith, which takes care of updating the shape
info as needed.
This fixes a crash in the newly added test case.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 8 ++---
.../dot-product-transpose-int.ll | 30 +++++++++++++++++++
2 files changed, 33 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 6a9ec48864b2c5..85cc069f098ea4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -259,7 +259,7 @@ static bool isUniformShape(Value *V) {
/// Return the ShapeInfo for the result of \p I, it it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
- const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+ const DenseMap<Value *, ShapeInfo> &ShapeMap) {
Value *M;
Value *N;
Value *K;
@@ -493,10 +493,8 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
- /// using shape information as well. A ValueMap is used so that when
- /// sub-passes like optimizeTransposes performs RAUW the map stays
- /// up-to-date.
- ValueMap<Value *, ShapeInfo> ShapeMap;
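+ /// using shape information as well. Entries are keyed by raw pointers and,
+ /// unlike a ValueMap, are not updated on RAUW or erasure; users must keep
+ /// the map in sync (e.g. via updateShapeAndReplaceAllUsesWith).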
+ DenseMap<Value *, ShapeInfo> ShapeMap;
/// List of instructions to remove. While lowering, we are not replacing all
/// users of a lowered instruction, if shape information is available and
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
index 2fd77e245a34e5..aadaf1ffffb23a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
+
+; Regression test: the shufflevector %shuffle must not pick up shape info when
+; a value that has shape info (here the transpose %t.shuffle) is RAUW'd with it.
+; With the old ValueMap-based ShapeMap this happened automatically and crashed.
+define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
+; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
+; CHECK-NEXT: ret <1 x i32> [[TMP11]]
+;
+entry:
+ %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+ %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+ %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
+ %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
+ ret <1 x i32> %m
+}
+
+declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)
>From 852233f634c84cce486580e857e0b20da87f72c8 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 2 Dec 2024 18:33:01 +0000
Subject: [PATCH 2/3] !fixup also remove entries from ShapeMap when removing
instructions.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 51 +++++++++++--------
1 file changed, 29 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 85cc069f098ea4..3e0da3ccd9ae5c 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
}
-/// Erase \p V from \p BB and move \II forward to avoid invalidating
-/// iterators.
-static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
- BasicBlock &BB) {
- auto *Inst = cast<Instruction>(V);
- // Still used, don't erase.
- if (!Inst->use_empty())
- return;
- if (II != BB.rend() && Inst == &*II)
- ++II;
- Inst->eraseFromParent();
-}
-
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
@@ -756,6 +743,26 @@ class LowerMatrixIntrinsics {
return Operation(T0, Shape0.t(), T1, Shape1.t());
}
+ /// Erase \p Inst from its parent and drop its ShapeMap entry, if any, so a
+ /// later instruction allocated at the same address cannot see stale info.
+ void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
+ ShapeMap.erase(Inst); // No-op if Inst has no shape info recorded.
+ Inst->eraseFromParent();
+ }
+
+ /// Erase \p V (and its ShapeMap entry) if it has no uses; \p V must be an
+ /// Instruction. If \p II points at \p V, advance it first so that the
+ /// reverse iteration over \p BB is not left on an erased instruction.
+ void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+ BasicBlock &BB) {
+ auto *Inst = cast<Instruction>(V);
+ // Still used, don't erase.
+ if (!Inst->use_empty())
+ return;
+ if (II != BB.rend() && Inst == &*II)
+ ++II;
+ eraseFromParentAndRemoveFromShapeMap(Inst);
+ }
+
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
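- // We need to remove Old from the ShapeMap otherwise RAUW will replace it
- // with New. We should only add New it it supportsShapeInfo so we insert
+ // We need to remove Old from the ShapeMap, as its DenseMap entry does not
+ // follow RAUW. We should only add New if it supportsShapeInfo so we insert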
@@ -869,13 +876,13 @@ class LowerMatrixIntrinsics {
void liftTranspose(Instruction &I) {
// Erase dead Instructions after lifting transposes from binops.
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+ auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
if (T.use_empty())
- T.eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(&T);
if (A->use_empty())
- cast<Instruction>(A)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
if (A != B && B->use_empty())
- cast<Instruction>(B)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
};
Value *A, *B, *AT, *BT;
@@ -1482,7 +1489,7 @@ class LowerMatrixIntrinsics {
m_Value(Arg)))) {
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
Op->replaceAllUsesWith(NewLoad);
- cast<Instruction>(Op)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
return;
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
@@ -1851,15 +1858,15 @@ class LowerMatrixIntrinsics {
// Mark eliminated instructions as fused and remove them.
FusedInsts.insert(Store);
FusedInsts.insert(MatMul);
- Store->eraseFromParent();
- MatMul->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(Store);
+ eraseFromParentAndRemoveFromShapeMap(MatMul);
if (LoadOp0->hasNUses(0)) {
FusedInsts.insert(LoadOp0);
- LoadOp0->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp0);
}
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
FusedInsts.insert(LoadOp1);
- LoadOp1->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp1);
}
}
>From e7206712342d66c3063e8c64731d06cafef73ecb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 2 Dec 2024 22:04:10 +0000
Subject: [PATCH 3/3] !fixup add test case showing issues
---
.../transpose-opts-lifting.ll | 20 +++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index fcf83b03bc3d23..1b3b41d8cfe1f8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -144,8 +144,28 @@ entry:
ret <6 x double> %mul
}
+; Regression test: instructions erased while lifting transposes must also be
+; removed from ShapeMap. ShapeMap is keyed by raw pointers, so a stale entry
+; could otherwise be matched by a new instruction allocated at the same address.
+define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
+; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
+; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
+; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
+; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
+; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
+ %add = fadd <6 x float> %c, %m
+ %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
+ store <6 x float> %t, ptr %dst, align 4
+ ret void
+}
+
declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)
+declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
+declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)
More information about the llvm-commits
mailing list