[llvm] [Matrix] Use DenseMap for ShapeMap instead of ValueMap. (PR #118282)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 4 05:22:10 PST 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/118282
>From b6d029c60bcf5a73b795d669b4e6d8548b9192f8 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 1 Dec 2024 19:01:02 +0000
Subject: [PATCH 1/4] [Matrix] Use DenseMap for ShapeMap instead of ValueMap.
ValueMap automatically moves an entry to the replacement value when its key
has been RAUW'd. This can add instructions that are expected to have no
shape info to the map (e.g. the shufflevector in the added test case),
which leads to incorrect results.
ValueMap was originally used for the transpose optimizations, but those now
all use updateShapeAndReplaceAllUsesWith, which takes care of updating the
shape info as needed.
This fixes a crash in the newly added test case.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 8 ++---
.../dot-product-transpose-int.ll | 30 +++++++++++++++++++
2 files changed, 33 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 6a9ec48864b2c5..85cc069f098ea4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -259,7 +259,7 @@ static bool isUniformShape(Value *V) {
/// Return the ShapeInfo for the result of \p I, it it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
- const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+ const DenseMap<Value *, ShapeInfo> &ShapeMap) {
Value *M;
Value *N;
Value *K;
@@ -493,10 +493,8 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
- /// using shape information as well. A ValueMap is used so that when
- /// sub-passes like optimizeTransposes performs RAUW the map stays
- /// up-to-date.
- ValueMap<Value *, ShapeInfo> ShapeMap;
+ /// using shape information as well.
+ DenseMap<Value *, ShapeInfo> ShapeMap;
/// List of instructions to remove. While lowering, we are not replacing all
/// users of a lowered instruction, if shape information is available and
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
index 2fd77e245a34e5..aadaf1ffffb23a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
+
+define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
+; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
+; CHECK-NEXT: ret <1 x i32> [[TMP11]]
+;
+entry:
+ %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+ %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+ %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
+ %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
+ ret <1 x i32> %m
+}
+
+declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)
>From 3cfc930842050582f6fb7e2170ef866b9ae6a36f Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 2 Dec 2024 18:33:01 +0000
Subject: [PATCH 2/4] !fixup also remove entries from ShapeMap when removing
instructions.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 51 +++++++++++--------
1 file changed, 29 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 85cc069f098ea4..3e0da3ccd9ae5c 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
}
-/// Erase \p V from \p BB and move \II forward to avoid invalidating
-/// iterators.
-static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
- BasicBlock &BB) {
- auto *Inst = cast<Instruction>(V);
- // Still used, don't erase.
- if (!Inst->use_empty())
- return;
- if (II != BB.rend() && Inst == &*II)
- ++II;
- Inst->eraseFromParent();
-}
-
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
@@ -756,6 +743,26 @@ class LowerMatrixIntrinsics {
return Operation(T0, Shape0.t(), T1, Shape1.t());
}
+ void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
+ auto Iter = ShapeMap.find(Inst);
+ if (Iter != ShapeMap.end())
+ ShapeMap.erase(Iter);
+ Inst->eraseFromParent();
+ }
+
+ /// Erase \p V from \p BB and move \II forward to avoid invalidating
+ /// iterators.
+ void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+ BasicBlock &BB) {
+ auto *Inst = cast<Instruction>(V);
+ // Still used, don't erase.
+ if (!Inst->use_empty())
+ return;
+ if (II != BB.rend() && Inst == &*II)
+ ++II;
+ eraseFromParentAndRemoveFromShapeMap(Inst);
+ }
+
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
// with New. We should only add New it it supportsShapeInfo so we insert
@@ -869,13 +876,13 @@ class LowerMatrixIntrinsics {
void liftTranspose(Instruction &I) {
// Erase dead Instructions after lifting transposes from binops.
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+ auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
if (T.use_empty())
- T.eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(&T);
if (A->use_empty())
- cast<Instruction>(A)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
if (A != B && B->use_empty())
- cast<Instruction>(B)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
};
Value *A, *B, *AT, *BT;
@@ -1482,7 +1489,7 @@ class LowerMatrixIntrinsics {
m_Value(Arg)))) {
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
Op->replaceAllUsesWith(NewLoad);
- cast<Instruction>(Op)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
return;
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
@@ -1851,15 +1858,15 @@ class LowerMatrixIntrinsics {
// Mark eliminated instructions as fused and remove them.
FusedInsts.insert(Store);
FusedInsts.insert(MatMul);
- Store->eraseFromParent();
- MatMul->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(Store);
+ eraseFromParentAndRemoveFromShapeMap(MatMul);
if (LoadOp0->hasNUses(0)) {
FusedInsts.insert(LoadOp0);
- LoadOp0->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp0);
}
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
FusedInsts.insert(LoadOp1);
- LoadOp1->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp1);
}
}
>From a92c86bf4d5287bb5aad3f9e2c0cd03ff4303e77 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 2 Dec 2024 22:04:10 +0000
Subject: [PATCH 3/4] !fixup add test case showing issues
---
.../transpose-opts-lifting.ll | 20 +++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index fcf83b03bc3d23..1b3b41d8cfe1f8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -144,8 +144,28 @@ entry:
ret <6 x double> %mul
}
+define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
+; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
+; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
+; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
+; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
+; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
+ %add = fadd <6 x float> %c, %m
+ %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
+ store <6 x float> %t, ptr %dst, align 4
+ ret void
+}
+
declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)
+declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
+declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)
>From db6cdac075620e7ebf88fd8e2f88ae42540e5f7c Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 4 Dec 2024 13:21:04 +0000
Subject: [PATCH 4/4] !fixup improve docs.
---
.../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3e0da3ccd9ae5c..29844c4630751e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -480,7 +480,15 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
- /// using shape information as well.
+ /// using shape information as well. Note that extra care is needed when
+ /// erasing or RAUW'ing a value that is present in ShapeMap. If the
+ /// replacement is also a matrix operation, use
+ /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
+ /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
+ /// want to add shape information for a replacement instruction. When directly
+ /// erasing a value with an entry in ShapeMap, use
+ /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
+ /// accordingly.
DenseMap<Value *, ShapeInfo> ShapeMap;
/// List of instructions to remove. While lowering, we are not replacing all
@@ -743,6 +751,8 @@ class LowerMatrixIntrinsics {
return Operation(T0, Shape0.t(), T1, Shape1.t());
}
+ /// Remove the ShapeMap entry for \p Inst (if one exists), then erase \p Inst
+ /// from its parent.
void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
auto Iter = ShapeMap.find(Inst);
if (Iter != ShapeMap.end())
@@ -763,6 +773,8 @@ class LowerMatrixIntrinsics {
eraseFromParentAndRemoveFromShapeMap(Inst);
}
+ /// Give \p New \p Old's shape info in ShapeMap (only if \p New supports shape
+ /// info), erase the entry for \p Old and replace all uses of \p Old with \p New.
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
// with New. We should only add New it it supportsShapeInfo so we insert
More information about the llvm-commits
mailing list