[llvm] f10153f - [Matrix] Handle integer types when distributing transposes across adds.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 21 08:35:45 PDT 2023
Author: Florian Hahn
Date: 2023-04-21T16:35:11+01:00
New Revision: f10153fe91508966aef062ba062271631f2c0f88
URL: https://github.com/llvm/llvm-project/commit/f10153fe91508966aef062ba062271631f2c0f88
DIFF: https://github.com/llvm/llvm-project/commit/f10153fe91508966aef062ba062271631f2c0f88.diff
LOG: [Matrix] Handle integer types when distributing transposes across adds.
The current code did not properly account for integer matrices: it always
created an FAdd, which is invalid for integer operands. Check whether the
operands are floating-point or integer matrices and use FAdd/Add accordingly.
This is already done for other cases, like multiplies.
Fixes #62281.
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 161b14897069b..f9a149f3616aa 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -838,10 +838,13 @@ class LowerMatrixIntrinsics {
auto NewInst = distributeTransposes(
TAMA, {R, C}, TAMB, {R, C}, Builder,
[&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
- auto *FAdd =
- cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd"));
- setShapeInfo(FAdd, Shape0);
- return FAdd;
+ bool IsFP = I.getType()->isFPOrFPVectorTy();
+ auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
+ : LocalBuilder.CreateAdd(T0, T1, "madd");
+
+ auto *Result = cast<Instruction>(Add);
+ setShapeInfo(Result, Shape0);
+ return Result;
});
updateShapeAndReplaceAllUsesWith(I, NewInst);
eraseFromParentAndMove(&I, II, BB);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
index 7510f43d31986..4a3b121afb6f5 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -121,8 +121,8 @@ define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
-; CHECK-NEXT: [[MFADD1:%.*]] = fadd <9 x double> [[A]], [[B]]
-; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD1]], i32 3, i32 3)
+; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
+; CHECK-NEXT: [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
; CHECK-NEXT: store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
@@ -203,8 +203,8 @@ define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, doubl
; CHECK-NEXT: [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
-; CHECK-NEXT: [[MFADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
-; CHECK-NEXT: store <9 x double> [[MFADD]], ptr [[E:%.*]], align 128
+; CHECK-NEXT: [[MADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
+; CHECK-NEXT: store <9 x double> [[MADD]], ptr [[E:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
@@ -257,3 +257,25 @@ entry:
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
+
+
+; (a * b + c)^T -> (a * b)^T + c^T with integer types.
+define noundef <4 x i32> @mul_add_transpose_int(<4 x i32> noundef %a, <4 x i32> noundef %b, <4 x i32> noundef %c) {
+; CHECK-LABEL: @mul_add_transpose_int(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], i32 2, i32 2, i32 2)
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[TMP0]], i32 2, i32 2)
+; CHECK-NEXT: [[C_T:%.*]] = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> [[C:%.*]], i32 2, i32 2)
+; CHECK-NEXT: [[MADD:%.*]] = add <4 x i32> [[TMP1]], [[C_T]]
+; CHECK-NEXT: ret <4 x i32> [[MADD]]
+;
+entry:
+ %mul = tail call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2)
+ %add = add <4 x i32> %mul, %c
+ %t = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %add, i32 2, i32 2)
+ ret <4 x i32> %t
+}
+
+declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)
+
+declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)
More information about the llvm-commits
mailing list