[llvm] 78148eb - [Matrix] Fix crash during dot product lowering.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 12 07:09:07 PDT 2023


Author: Florian Hahn
Date: 2023-04-12T15:08:39+01:00
New Revision: 78148eba491142ddda3b8de8199b6d80fb34ed1a

URL: https://github.com/llvm/llvm-project/commit/78148eba491142ddda3b8de8199b6d80fb34ed1a
DIFF: https://github.com/llvm/llvm-project/commit/78148eba491142ddda3b8de8199b6d80fb34ed1a.diff

LOG: [Matrix] Fix crash during dot product lowering.

Perform dot-product lowering before instruction fusion to avoid a crash in
the newly added test. Also update lowerDotProduct to properly mark the
optimized matmul as fused.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 1147f018d5655..82e717fedb704 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -978,17 +978,15 @@ class LowerMatrixIntrinsics {
         MatrixInsts.push_back(&I);
       }
 
-    // Second, try to fuse candidates.
+    // Second, try to lower any dot products
     SmallPtrSet<Instruction *, 16> FusedInsts;
+    for (CallInst *CI : MaybeFusableInsts)
+      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
+
+    // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
       LowerMatrixMultiplyFused(CI, FusedInsts);
 
-    // Third, try to lower any dot products
-    for (CallInst *CI : MaybeFusableInsts) {
-      if (FusedInsts.contains(CI)) // skip if already fused
-        continue;
-      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
-    }
     Changed = !FusedInsts.empty();
 
     // Fourth, lower remaining instructions with shape information.
@@ -1324,7 +1322,8 @@ class LowerMatrixIntrinsics {
   void lowerDotProduct(CallInst *MatMul,
                        SmallPtrSet<Instruction *, 16> &FusedInsts,
                        FastMathFlags FMF) {
-    if (MatrixLayout != MatrixLayoutTy::ColumnMajor)
+    if (FusedInsts.contains(MatMul) ||
+        MatrixLayout != MatrixLayoutTy::ColumnMajor)
       return;
     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
@@ -1410,7 +1409,8 @@ class LowerMatrixIntrinsics {
     Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                          Result, uint64_t(0));
     MatMul->replaceAllUsesWith(Result);
-    MatMul->eraseFromParent();
+    FusedInsts.insert(MatMul);
+    ToRemove.push_back(MatMul);
   }
 
   /// Compute \p Result += \p A * \p B for input matrices with left-associating

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
index 0129dff311216..dd5c4cf3f4b4f 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
@@ -624,3 +624,41 @@ entry:
 }
 
 declare <1 x i16> @llvm.matrix.multiply.v1i16.v6i16.v6i16(<6 x i16>, <6 x i16>, i32, i32, i32)
+
+define void @transposed_multiply_feeding_dot_produc_v4i32() {
+; CHECK-LABEL: @transposed_multiply_feeding_dot_produc_v4i32(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SCALAR_SPLAT_SPLAT_I_I_I_I:%.*]] = shufflevector <4 x i32> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[SCALAR_SPLAT_SPLAT_I_I_I_I]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x i32> [[SCALAR_SPLAT_SPLAT_I_I_I_I]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP0:%.*]] = mul <2 x i32> [[BLOCK]], zeroinitializer
+; CHECK-NEXT:    [[BLOCK2:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <2 x i32> [[BLOCK2]], zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = add <2 x i32> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP3]], <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK3:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP5:%.*]] = mul <2 x i32> [[BLOCK3]], zeroinitializer
+; CHECK-NEXT:    [[BLOCK4:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul <2 x i32> [[BLOCK4]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = add <2 x i32> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP7]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP8]], <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i32> [[TMP4]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul <4 x i32> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i64 0
+; CHECK-NEXT:    ret void
+;
+entry:
+  %scalar.splat.splat.i.i.i.i = shufflevector <4 x i32> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> zeroinitializer
+  %0 = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> zeroinitializer, <4 x i32> %scalar.splat.splat.i.i.i.i, i32 2, i32 2, i32 2)
+  %1 = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %0, i32 4, i32 1)
+  %2 = call <1 x i32> @llvm.matrix.multiply.v1i32.v4i32.v4i32(<4 x i32> %1, <4 x i32> zeroinitializer, i32 1, i32 4, i32 1)
+  ret void
+}
+
+declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32 immarg, i32 immarg)
+
+declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32 immarg, i32 immarg, i32 immarg)


        


More information about the llvm-commits mailing list