[llvm-branch-commits] [llvm] release/19.x: [Matrix] Skip already fused instructions before trying to fuse multiply. (PR #118020)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Nov 28 08:17:26 PST 2024


https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/118020

Backport 12cefcc7ecd2615069206b35b0ea81b9e78bb1ea

Requested by: @fhahn

>From c186e5d3e0e88ef5f285c0fd5f33ae826f7d9221 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 28 Nov 2024 16:11:39 +0000
Subject: [PATCH] [Matrix] Skip already fused instructions before trying to
 fuse multiply.

lowerDotProduct called above may already lower a matrix multiply and
mark it as processed by adding it to FusedInsts. Don't try to process it
again in LowerMatrixMultiplyFused if FusedInsts already contains it.

Without this change, we trigger an assertion when trying to erase the
same original matrix multiply twice.

(cherry picked from commit 12cefcc7ecd2615069206b35b0ea81b9e78bb1ea)
---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  3 +-
 .../dot-product-int-also-fusable-multiply.ll  | 49 +++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 6a681fd9339717..a44a123fdf8cda 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1014,7 +1014,8 @@ class LowerMatrixIntrinsics {
 
     // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
-      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
+      if (!FusedInsts.contains(CI))
+        LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
     Changed = !FusedInsts.empty();
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
new file mode 100644
index 00000000000000..b78d56646d9e4d
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
+
+define void @test(ptr %p, <8 x i32> %x) {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7
+; CHECK-NEXT:    [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP17]])
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0
+; CHECK-NEXT:    [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0
+; CHECK-NEXT:    store i32 [[E]], ptr [[P]], align 4
+; CHECK-NEXT:    ret void
+;
+  %l = load <8 x i32>, ptr %p, align 4
+  %t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8)
+  %m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1)
+  %e = extractelement <1 x i32> %m, i64 0
+  store i32 %e, ptr %p, align 4
+  ret void
+}
+
+declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg)
+
+declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg)



More information about the llvm-branch-commits mailing list