[llvm] ffb1c21 - [Matrix] Fix crash in liftTranspose when instructions are folded.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 04:58:28 PST 2024


Author: Florian Hahn
Date: 2024-12-05T12:57:54Z
New Revision: ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a

URL: https://github.com/llvm/llvm-project/commit/ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a
DIFF: https://github.com/llvm/llvm-project/commit/ffb1c21bd4697ab5a6f11e2f2eba4fefaa298f2a.diff

LOG: [Matrix] Fix crash in liftTranspose when instructions are folded.

Builder.Create(F)Add may constant-fold its inputs and return a constant
instead of an instruction. Account for that instead of crashing.

Added: 
    llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 29844c4630751e..796fba67ee2576 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -924,8 +924,7 @@ class LowerMatrixIntrinsics {
              match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
       IRBuilder<> Builder(&I);
-      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
-      setShapeInfo(Add, {R, C});
+      auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
       MatrixBuilder MBuilder(Builder);
       Instruction *NewInst = MBuilder.CreateMatrixTranspose(
           Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
@@ -934,9 +933,13 @@ class LowerMatrixIntrinsics {
                  computeShapeInfoForInst(&I, ShapeMap) &&
              "Shape of new instruction doesn't match original shape.");
       CleanupBinOp(I, A, B);
-      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
-                 ShapeMap[Add] &&
-             "Shape of updated addition doesn't match cached shape.");
+      if (auto *AddI = dyn_cast<Instruction>(Add)) {
+        setShapeInfo(AddI, {R, C});
+        assert(
+            computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
+                ShapeMap[AddI] &&
+            "Shape of updated addition doesn't match cached shape.");
+      }
     }
   }
 

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll
new file mode 100644
index 00000000000000..4eaa0e63a29c4c
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+define <8 x float> @transpose_constant_fold_fadd_AT_BT() {
+; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fadd_AT_BT() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    ret <8 x float> splat (float 2.000000e+00)
+;
+entry:
+  %t = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 8, i32 1)
+  %f = fadd <8 x float> %t, %t
+  ret <8 x float> %f
+}
+
+define <8 x float> @transpose_constant_fold_fmul_A_k() {
+; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fmul_A_k() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <8 x float> splat (float 3.000000e+00), <8 x float> poison, <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[TMP0]], <4 x float> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x float> [[TMP2]]
+;
+entry:
+  %t.1 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 4, i32 2)
+  %splat = shufflevector <8 x float> splat (float 3.0), <8 x float> poison, <8 x i32> zeroinitializer
+  %m = fmul <8 x float> %t.1, %splat
+  %t.2 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> %m, i32 2, i32 4)
+  ret <8 x float> %t.2
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare <8 x float> @llvm.matrix.transpose.v8f32(<8 x float>, i32 immarg, i32 immarg) #0
+
+attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }


        


More information about the llvm-commits mailing list