[llvm] ffb1c21 - [Matrix] Fix crash in liftTranspose when instructions are folded.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 5 04:58:28 PST 2024
Author: Florian Hahn
Date: 2024-12-05T12:57:54Z
New Revision: ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a
URL: https://github.com/llvm/llvm-project/commit/ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a
DIFF: https://github.com/llvm/llvm-project/commit/ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a.diff
LOG: [Matrix] Fix crash in liftTranspose when instructions are folded.
Builder.Create(F)Add may constant-fold its inputs and return a constant
instead of an instruction. Account for that instead of crashing.
Added:
llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 29844c4630751e..796fba67ee2576 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -924,8 +924,7 @@ class LowerMatrixIntrinsics {
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
IRBuilder<> Builder(&I);
- auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
- setShapeInfo(Add, {R, C});
+ auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
MatrixBuilder MBuilder(Builder);
Instruction *NewInst = MBuilder.CreateMatrixTranspose(
Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
@@ -934,9 +933,13 @@ class LowerMatrixIntrinsics {
computeShapeInfoForInst(&I, ShapeMap) &&
"Shape of new instruction doesn't match original shape.");
CleanupBinOp(I, A, B);
- assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
- ShapeMap[Add] &&
- "Shape of updated addition doesn't match cached shape.");
+ if (auto *AddI = dyn_cast<Instruction>(Add)) {
+ setShapeInfo(AddI, {R, C});
+ assert(
+ computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
+ ShapeMap[AddI] &&
+ "Shape of updated addition doesn't match cached shape.");
+ }
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll
new file mode 100644
index 00000000000000..4eaa0e63a29c4c
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+define <8 x float> @transpose_constant_fold_fadd_AT_BT() {
+; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fadd_AT_BT() {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: ret <8 x float> splat (float 2.000000e+00)
+;
+entry:
+ %t = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 8, i32 1)
+ %f = fadd <8 x float> %t, %t
+ ret <8 x float> %f
+}
+
+define <8 x float> @transpose_constant_fold_fmul_A_k() {
+; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fmul_A_k() {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <8 x float> splat (float 3.000000e+00), <8 x float> poison, <8 x i32> zeroinitializer
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP0:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT1]]
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x float> [[TMP0]], <4 x float> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: ret <8 x float> [[TMP2]]
+;
+entry:
+ %t.1 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 4, i32 2)
+ %splat = shufflevector <8 x float> splat (float 3.0), <8 x float> poison, <8 x i32> zeroinitializer
+ %m = fmul <8 x float> %t.1, %splat
+ %t.2 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> %m, i32 2, i32 4)
+ ret <8 x float> %t.2
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare <8 x float> @llvm.matrix.transpose.v8f32(<8 x float>, i32 immarg, i32 immarg) #0
+
+attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
More information about the llvm-commits mailing list