[llvm] 7b6e0d9 - [Matrix] Use DenseMap for ShapeMap instead of ValueMap. (#118282)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 4 06:51:35 PST 2024
Author: Florian Hahn
Date: 2024-12-04T14:51:31Z
New Revision: 7b6e0d9fc3993f3e3df596fd16d97e2ed2e1d0aa
URL: https://github.com/llvm/llvm-project/commit/7b6e0d9fc3993f3e3df596fd16d97e2ed2e1d0aa
DIFF: https://github.com/llvm/llvm-project/commit/7b6e0d9fc3993f3e3df596fd16d97e2ed2e1d0aa.diff
LOG: [Matrix] Use DenseMap for ShapeMap instead of ValueMap. (#118282)
ValueMap automatically updates its entries to refer to the new value when
the old value has been RAUW'd. This can cause instructions that are not
expected to have shape info to be added to the map (e.g. a shufflevector,
as in the added test case).
This leads to incorrect results. ValueMap was originally used for the
transpose optimizations, but they now all use
updateShapeAndReplaceAllUsesWith, which takes care of updating the shape
info as needed.
This fixes a crash in the newly added test cases.
PR: https://github.com/llvm/llvm-project/pull/118282
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 6a9ec48864b2c5..29844c4630751e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
-/// Erase \p V from \p BB and move \II forward to avoid invalidating
-/// iterators.
-static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
- BasicBlock &BB) {
- auto *Inst = cast<Instruction>(V);
- // Still used, don't erase.
- if (!Inst->use_empty())
- return;
- if (II != BB.rend() && Inst == &*II)
- ++II;
- Inst->eraseFromParent();
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
/// Return the ShapeInfo for the result of \p I, it it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
- const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+ const DenseMap<Value *, ShapeInfo> &ShapeMap) {
Value *M;
Value *N;
Value *K;
@@ -493,10 +480,16 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
- /// using shape information as well. A ValueMap is used so that when
- /// sub-passes like optimizeTransposes performs RAUW the map stays
- /// up-to-date.
- ValueMap<Value *, ShapeInfo> ShapeMap;
+ /// using shape information as well. Note that extra care is needed when
+ /// erasing or RAUW'ing a value that is present in ShapeMap. If the
+ /// replacement is also a matrix operation, use
+ /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
+ /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
+ /// want to add shape information for a replacement instruction. When directly
+ /// erasing a value with an entry in ShapeMap, use
+ /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
+ /// accordingly.
+ DenseMap<Value *, ShapeInfo> ShapeMap;
/// List of instructions to remove. While lowering, we are not replacing all
/// users of a lowered instruction, if shape information is available and
@@ -758,6 +751,30 @@ class LowerMatrixIntrinsics {
return Operation(T0, Shape0.t(), T1, Shape1.t());
+ /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
+ /// itself.
+ void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
+ auto Iter = ShapeMap.find(Inst);
+ if (Iter != ShapeMap.end())
+ ShapeMap.erase(Iter);
+ Inst->eraseFromParent();
+ }
+ /// Erase \p V from \p BB and move \II forward to avoid invalidating
+ /// iterators.
+ void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+ BasicBlock &BB) {
+ auto *Inst = cast<Instruction>(V);
+ // Still used, don't erase.
+ if (!Inst->use_empty())
+ return;
+ if (II != BB.rend() && Inst == &*II)
+ ++II;
+ eraseFromParentAndRemoveFromShapeMap(Inst);
+ }
+ /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
+ /// entry for \p Old and replace all uses of \p Old with \p New.
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
// with New. We should only add New it it supportsShapeInfo so we insert
@@ -871,13 +888,13 @@ class LowerMatrixIntrinsics {
void liftTranspose(Instruction &I) {
// Erase dead Instructions after lifting transposes from binops.
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+ auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
if (T.use_empty())
- T.eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(&T);
if (A->use_empty())
- cast<Instruction>(A)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
if (A != B && B->use_empty())
- cast<Instruction>(B)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
Value *A, *B, *AT, *BT;
@@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics {
m_Value(Arg)))) {
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
- cast<Instruction>(Op)->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
@@ -1853,15 +1870,15 @@ class LowerMatrixIntrinsics {
// Mark eliminated instructions as fused and remove them.
- Store->eraseFromParent();
- MatMul->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(Store);
+ eraseFromParentAndRemoveFromShapeMap(MatMul);
if (LoadOp0->hasNUses(0)) {
- LoadOp0->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp0);
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
- LoadOp1->eraseFromParent();
+ eraseFromParentAndRemoveFromShapeMap(LoadOp1);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
index 2fd77e245a34e5..aadaf1ffffb23a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
+define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
+; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
+; CHECK-NEXT: ret <1 x i32> [[TMP11]]
+ %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+ %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+ %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
+ %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
+ ret <1 x i32> %m
+declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index fcf83b03bc3d23..1b3b41d8cfe1f8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -144,8 +144,28 @@ entry:
ret <6 x double> %mul
+define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
+; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
+; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
+; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
+; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
+; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
+; CHECK-NEXT: ret void
+ %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
+ %add = fadd <6 x float> %c, %m
+ %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
+ store <6 x float> %t, ptr %dst, align 4
+ ret void
declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)
+declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
+declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)
More information about the llvm-commits
mailing list