[llvm] 12cefcc - [Matrix] Skip already fused instructions before trying to fuse multiply.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 28 08:11:47 PST 2024
Author: Florian Hahn
Date: 2024-11-28T16:11:40Z
New Revision: 12cefcc7ecd2615069206b35b0ea81b9e78bb1ea
URL: https://github.com/llvm/llvm-project/commit/12cefcc7ecd2615069206b35b0ea81b9e78bb1ea
DIFF: https://github.com/llvm/llvm-project/commit/12cefcc7ecd2615069206b35b0ea81b9e78bb1ea.diff
LOG: [Matrix] Skip already fused instructions before trying to fuse multiply.
lowerDotProduct, called above, may already lower a matrix multiply and
mark it as processed by adding it to FusedInsts. Don't try to process it
again in LowerMatrixMultiplyFused; skip it if it is already in FusedInsts.
Without this change, we trigger an assertion when trying to erase the
same original matrix multiply twice.
Added:
llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index eaf58ea8dd9d06..62ab83dae8ae66 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1014,7 +1014,8 @@ class LowerMatrixIntrinsics {
// Third, try to fuse candidates.
for (CallInst *CI : MaybeFusableInsts)
- LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
+ if (!FusedInsts.contains(CI))
+ LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
Changed = !FusedInsts.empty();
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
new file mode 100644
index 00000000000000..b78d56646d9e4d
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
+
+define void @test(ptr %p, <8 x i32> %x) {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) {
+; CHECK-NEXT: [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1>
+; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2>
+; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3>
+; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4>
+; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5>
+; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6>
+; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7>
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT: [[TMP4:%.*]] = insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2
+; CHECK-NEXT: [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0
+; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3
+; CHECK-NEXT: [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0
+; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4
+; CHECK-NEXT: [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0
+; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5
+; CHECK-NEXT: [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0
+; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6
+; CHECK-NEXT: [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0
+; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7
+; CHECK-NEXT: [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]]
+; CHECK-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP17]])
+; CHECK-NEXT: [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0
+; CHECK-NEXT: [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0
+; CHECK-NEXT: store i32 [[E]], ptr [[P]], align 4
+; CHECK-NEXT: ret void
+;
+ %l = load <8 x i32>, ptr %p, align 4
+ %t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8)
+ %m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1)
+ %e = extractelement <1 x i32> %m, i64 0
+ store i32 %e, ptr %p, align 4
+ ret void
+}
+
+declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg)
+
+declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg)
More information about the llvm-commits
mailing list