[llvm] [Matrix] Fix dimensions when hoisting transpose across add. (PR #81507)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 10:44:10 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/81507

>From c2fde6dacb03700b178e0bc039dfe928705f20a3 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 12 Feb 2024 18:14:23 +0000
Subject: [PATCH 1/2] [Matrix] Fix dimensions when hoisting transpose across
 add.

Row and column arguments for matrix_transpose indicate the shape of the
operand. When hoisting the transpose to the result of the add, the add
operates on the original operand's shape, and so does the hoisted
transpose.

This patch also adds asserts that the shape of the original add matches
the shape of the hoisted transpose, and that the shape of the new add
matches the shape cached for it.

The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  21 +++-
 .../propagate-backward.ll                     | 119 +++++++++---------
 .../transpose-opts-lifting.ll                 |  44 ++++++-
 3 files changed, 115 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 03e289f7a087ac..075388f69a85b5 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -765,8 +765,9 @@ class LowerMatrixIntrinsics {
     auto S = ShapeMap.find(&Old);
     if (S != ShapeMap.end()) {
       ShapeMap.erase(S);
-      if (supportsShapeInfo(New))
+      if (supportsShapeInfo(New)) {
         ShapeMap.insert({New, S->second});
+      }
     }
     Old.replaceAllUsesWith(New);
   }
@@ -898,20 +899,28 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       CleanupBinOp(I, A, B);
     }
-    // A^t + B ^t -> (A + B)^t
+    // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
+    // the shape of the second transpose is different, there's a shape conflict
+    // which gets resolved by picking the shape of the first operand.
     else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
              match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
              match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
-                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
       IRBuilder<> Builder(&I);
-      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
-      setShapeInfo(Add, {C, R});
+      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+      setShapeInfo(Add, {R, C});
       MatrixBuilder MBuilder(Builder);
       Instruction *NewInst = MBuilder.CreateMatrixTranspose(
-          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
       updateShapeAndReplaceAllUsesWith(I, NewInst);
+      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
+                 computeShapeInfoForInst(&I, ShapeMap) &&
+             "Shape of new instruction doesn't match original shape.");
       CleanupBinOp(I, A, B);
+      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
+                 ShapeMap[Add] &&
+             "Shape of updated addition doesn't match cached shape.");
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 82ae93b31035db..33a338dbc4ea0d 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -4,31 +4,35 @@
 define <8 x double> @fadd_transpose(<8 x double> %a, <8 x double> %b) {
 ; CHECK-LABEL: @fadd_transpose(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT3]]
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x double> [[SPLIT2]], [[SPLIT4]]
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x double> [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> poison, double [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x double> [[TMP1]], i64 0
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP3]], double [[TMP4]], i64 1
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x double> [[TMP0]], i64 1
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> poison, double [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x double> [[TMP1]], i64 1
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <2 x double> [[TMP7]], double [[TMP8]], i64 1
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x double> [[TMP0]], i64 2
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i64 0
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x double> [[TMP1]], i64 2
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <2 x double> [[TMP11]], double [[TMP12]], i64 1
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x double> [[TMP0]], i64 3
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <2 x double> poison, double [[TMP14]], i64 0
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x double> [[TMP1]], i64 3
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <2 x double> [[TMP15]], double [[TMP16]], i64 1
-; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP18]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[SPLIT]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[SPLIT1]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[SPLIT2]], [[SPLIT6]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[SPLIT3]], [[SPLIT7]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
 ; CHECK-NEXT:    ret <8 x double> [[TMP20]]
 ;
 entry:
@@ -42,40 +46,37 @@ define <8 x double> @load_fadd_transpose(ptr %A.Ptr, <8 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[A_PTR:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[A_PTR]], i64 2
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
-; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <2 x double>, ptr [[VEC_GEP5]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD4]], <2 x double> [[COL_LOAD6]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[TMP0]], <4 x double> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fadd <4 x double> [[SPLIT7]], [[SPLIT9]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i64 0
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x double> [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i64 1
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
-; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x double> poison, double [[TMP9]], i64 0
-; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x double> [[TMP4]], i64 1
-; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <2 x double> [[TMP10]], double [[TMP11]], i64 1
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <2 x double> poison, double [[TMP13]], i64 0
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x double> [[TMP4]], i64 2
-; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <2 x double> [[TMP14]], double [[TMP15]], i64 1
-; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
-; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i64 0
-; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <4 x double> [[TMP4]], i64 3
-; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <2 x double> [[TMP18]], double [[TMP19]], i64 1
-; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> [[TMP12]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x double> [[TMP16]], <2 x double> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP22]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    ret <8 x double> [[TMP23]]
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <2 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, ptr [[VEC_GEP4]], align 8
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[COL_LOAD]], [[SPLIT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[COL_LOAD1]], [[SPLIT6]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[COL_LOAD3]], [[SPLIT7]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[COL_LOAD5]], [[SPLIT8]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP20]]
 ;
 
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index d0c67556224c89..fcf83b03bc3d23 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -9,7 +9,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double>
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
@@ -25,7 +25,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions_ops_also_hav
 ; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
 ; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
 ; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
@@ -41,10 +41,28 @@ define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a,
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT:    ret <6 x double> [[T]]
+;
+entry:
+  %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
+  %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
+  %add = fadd <6 x double> %a.t, %b.t
+  ret <6 x double> %add
+}
+
+define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 1, i32 6)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
+  %a = load <6 x double>, ptr %a.ptr
+  %b = load <6 x double>, ptr %b.ptr
   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
   %add = fadd <6 x double> %a.t, %b.t
@@ -55,7 +73,7 @@ define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a,
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 
@@ -66,6 +84,24 @@ entry:
   ret <6 x double> %add
 }
 
+define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
+; CHECK-NEXT:    ret <6 x double> [[T]]
+;
+entry:
+  %a = load <6 x double>, ptr %a.ptr
+  %b = load <6 x double>, ptr %b.ptr
+  %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
+  %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 6, i32 1)
+  %add = fadd <6 x double> %a.t, %b.t
+  ret <6 x double> %add
+}
+
 define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
 ; CHECK-LABEL: define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:

>From 3a4296f520168f222e4efd7e4cd0226feb4c14da Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 12 Feb 2024 18:43:35 +0000
Subject: [PATCH 2/2] !fixup remove unneeded {}

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 075388f69a85b5..67c011b747acfd 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -765,9 +765,8 @@ class LowerMatrixIntrinsics {
     auto S = ShapeMap.find(&Old);
     if (S != ShapeMap.end()) {
       ShapeMap.erase(S);
-      if (supportsShapeInfo(New)) {
+      if (supportsShapeInfo(New))
         ShapeMap.insert({New, S->second});
-      }
     }
     Old.replaceAllUsesWith(New);
   }



More information about the llvm-commits mailing list