[llvm] da09b35 - [Matrix] Optimize matrix transposes around additions

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 11 15:22:55 PST 2023


Author: Francis Visoiu Mistrih
Date: 2023-01-11T15:21:59-08:00
New Revision: da09b35334aba76748a7531d96fd7e5ba1d66669

URL: https://github.com/llvm/llvm-project/commit/da09b35334aba76748a7531d96fd7e5ba1d66669
DIFF: https://github.com/llvm/llvm-project/commit/da09b35334aba76748a7531d96fd7e5ba1d66669.diff

LOG: [Matrix] Optimize matrix transposes around additions

First, sink the transposes down to the operands to expose and simplify
redundant ones. Then, lift them back up to reduce the number of
transposes that actually need to be realized. Sinking and then lifting
round-trips an add:

```
(A + B)^T -> A^T + B^T -> (A + B)^T
```

See tests for more examples.
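As an IR-level sketch of the lift step (value names here are illustrative;
the intrinsic signatures match the 3x3 double cases in the tests), two
realized transposes feeding an fadd become a single transpose of the sum:

```
; Before: both operands are transposed, so two transposes are realized.
%at = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %b, i32 3, i32 3)
%fadd = fadd <9 x double> %at, %bt

; After: add first, then realize a single transpose of the result.
%mfadd = fadd <9 x double> %a, %b
%mfadd_t = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %mfadd, i32 3, i32 3)
```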

Differential Revision: https://reviews.llvm.org/D133657

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
    llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 08c6406f0a30..17594b98c5bc 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -93,6 +93,19 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
   return cast<DILocalScope>(Scope)->getSubprogram();
 }
 
+/// Erase \p V from \p BB and move \p II forward to avoid invalidating
+/// iterators.
+static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+                                   BasicBlock &BB) {
+  auto *Inst = cast<Instruction>(V);
+  // Still used, don't erase.
+  if (!Inst->use_empty())
+    return;
+  if (II != BB.rend() && Inst == &*II)
+    ++II;
+  Inst->eraseFromParent();
+}
+
 /// Return true if V is a splat of a value (which is used when multiplying a
 /// matrix with a scalar).
 static bool isSplat(Value *V) {
@@ -107,6 +120,12 @@ auto m_AnyMul(const LTy &L, const RTy &R) {
   return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
 }
 
+/// Match any add operation (fp or integer).
+template <typename LTy, typename RTy>
+auto m_AnyAdd(const LTy &L, const RTy &R) {
+  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
+}
+
 namespace {
 
 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
@@ -725,135 +744,179 @@ class LowerMatrixIntrinsics {
     return Operation(T0, Shape0.t(), T1, Shape1.t());
   }
 
-  /// Try moving transposes in order to fold them away or into multiplies.
-  void optimizeTransposes() {
-    auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
-      // We need to remove Old from the ShapeMap otherwise RAUW will replace it
-      // with New. We should only add New it it supportsShapeInfo so we insert
-      // it conditionally instead.
-      auto S = ShapeMap.find(&Old);
-      if (S != ShapeMap.end()) {
-        ShapeMap.erase(S);
-        if (supportsShapeInfo(New))
-          ShapeMap.insert({New, S->second});
-      }
-      Old.replaceAllUsesWith(New);
+  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
+    // We need to remove Old from the ShapeMap, otherwise RAUW will replace
+    // it with New. We should only add New if it supportsShapeInfo, so we
+    // insert it conditionally instead.
+    auto S = ShapeMap.find(&Old);
+    if (S != ShapeMap.end()) {
+      ShapeMap.erase(S);
+      if (supportsShapeInfo(New))
+        ShapeMap.insert({New, S->second});
+    }
+    Old.replaceAllUsesWith(New);
+  }
+
+  /// Sink a top-level transpose inside matmuls and adds.
+  /// This creates and erases instructions as needed, and returns the newly
+  /// created instruction while updating the iterator to avoid invalidation. If
+  /// this returns nullptr, no new instruction was created.
+  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
+    BasicBlock &BB = *I.getParent();
+    IRBuilder<> IB(&I);
+    MatrixBuilder Builder(IB);
+
+    Value *TA, *TAMA, *TAMB;
+    ConstantInt *R, *K, *C;
+    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
+                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
+      return nullptr;
+
+    // Transpose of a transpose is a nop
+    Value *TATA;
+    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
+      updateShapeAndReplaceAllUsesWith(I, TATA);
+      eraseFromParentAndMove(&I, II, BB);
+      eraseFromParentAndMove(TA, II, BB);
+      return nullptr;
+    }
+
+    // k^T -> k
+    if (isSplat(TA)) {
+      updateShapeAndReplaceAllUsesWith(I, TA);
+      eraseFromParentAndMove(&I, II, BB);
+      return nullptr;
+    }
+
+    // (A * B)^t -> B^t * A^t
+    // RxK KxC      CxK   KxR
+    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
+                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
+                      m_ConstantInt(K), m_ConstantInt(C)))) {
+      auto NewInst = distributeTransposes(
+          TAMB, {K, C}, TAMA, {R, K}, Builder,
+          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
+                                                Shape0.NumColumns,
+                                                Shape1.NumColumns, "mmul");
+          });
+      updateShapeAndReplaceAllUsesWith(I, NewInst);
+      eraseFromParentAndMove(&I, II, BB);
+      eraseFromParentAndMove(TA, II, BB);
+      return NewInst;
+    }
+
+    // Same as above, but with a mul, which occurs when the matrix is
+    // multiplied with a scalar.
+    // (A * k)^t -> A^t * k
+    //  R  x  C     RxC
+    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
+        (isSplat(TAMA) || isSplat(TAMB))) {
+      IRBuilder<> LocalBuilder(&I);
+      // We know that the transposed operand is of shape RxC.
+      // And when multiplied with a scalar, the shape is preserved.
+      auto NewInst = distributeTransposes(
+          TAMA, {R, C}, TAMB, {R, C}, Builder,
+          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+            bool IsFP = I.getType()->isFPOrFPVectorTy();
+            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
+                             : LocalBuilder.CreateMul(T0, T1, "mmul");
+            auto *Result = cast<Instruction>(Mul);
+            setShapeInfo(Result, Shape0);
+            return Result;
+          });
+      updateShapeAndReplaceAllUsesWith(I, NewInst);
+      eraseFromParentAndMove(&I, II, BB);
+      eraseFromParentAndMove(TA, II, BB);
+      return NewInst;
+    }
+
+    // (A + B)^t -> A^t + B^t
+    // RxC RxC      CxR   CxR
+    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
+      IRBuilder<> LocalBuilder(&I);
+      auto NewInst = distributeTransposes(
+          TAMA, {R, C}, TAMB, {R, C}, Builder,
+          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+            bool IsFP = I.getType()->isFPOrFPVectorTy();
+            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "mfadd")
+                             : LocalBuilder.CreateAdd(T0, T1, "madd");
+            auto *Result = cast<Instruction>(Add);
+            setShapeInfo(Result, Shape0);
+            return Result;
+          });
+      updateShapeAndReplaceAllUsesWith(I, NewInst);
+      eraseFromParentAndMove(&I, II, BB);
+      eraseFromParentAndMove(TA, II, BB);
+      return NewInst;
+    }
+
+    return nullptr;
+  }
+
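+  /// Lift a transpose above a matmul or an add whose operands are both
+  /// transposed, so that only a single transpose needs to be realized.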
+  void liftTranspose(Instruction &I) {
+    // Erase dead Instructions after lifting transposes from binops.
+    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+      if (T.use_empty())
+        T.eraseFromParent();
+      if (A->use_empty())
+        cast<Instruction>(A)->eraseFromParent();
+      if (A != B && B->use_empty())
+        cast<Instruction>(B)->eraseFromParent();
     };
 
-    // First sink all transposes inside matmuls, hoping that we end up with NN,
-    // NT or TN variants.
+    Value *A, *B, *AT, *BT;
+    ConstantInt *R, *K, *C;
+    // A^t * B^t -> (B * A)^t
+    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
+                      m_Value(A), m_Value(B), m_ConstantInt(R),
+                      m_ConstantInt(K), m_ConstantInt(C))) &&
+        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
+        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
+      IRBuilder<> IB(&I);
+      MatrixBuilder Builder(IB);
+      Value *M = Builder.CreateMatrixMultiply(
+          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
+      setShapeInfo(M, {C, R});
+      Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
+                                                           R->getZExtValue());
+      updateShapeAndReplaceAllUsesWith(I, NewInst);
+      CleanupBinOp(I, A, B);
+    }
+    // A^t + B^t -> (A + B)^t
+    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
+             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
+                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
+             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
+                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+      IRBuilder<> Builder(&I);
+      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+      setShapeInfo(Add, {C, R});
+      MatrixBuilder MBuilder(Builder);
+      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
+          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+      updateShapeAndReplaceAllUsesWith(I, NewInst);
+      CleanupBinOp(I, A, B);
+    }
+  }
+
+  /// Try moving transposes in order to fold them away or into multiplies.
+  void optimizeTransposes() {
+    // First sink all transposes inside matmuls and adds, hoping that we end up
+    // with NN, NT or TN variants.
     for (BasicBlock &BB : reverse(Func)) {
       for (auto II = BB.rbegin(); II != BB.rend();) {
         Instruction &I = *II;
         // We may remove II.  By default continue on the next/prev instruction.
         ++II;
-        // If we were to erase II, move again.
-        auto EraseFromParent = [&II, &BB](Value *V) {
-          auto *Inst = cast<Instruction>(V);
-          if (Inst->use_empty()) {
-            if (II != BB.rend() && Inst == &*II) {
-              ++II;
-            }
-            Inst->eraseFromParent();
-          }
-        };
-
-        // If we're creating a new instruction, continue from there.
-        Instruction *NewInst = nullptr;
-
-        IRBuilder<> IB(&I);
-        MatrixBuilder Builder(IB);
-
-        Value *TA, *TAMA, *TAMB;
-        ConstantInt *R, *K, *C;
-        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
-                          m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) {
-          // Transpose of a transpose is a nop
-          Value *TATA;
-          if (match(TA,
-                    m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
-            ReplaceAllUsesWith(I, TATA);
-            EraseFromParent(&I);
-            EraseFromParent(TA);
-          }
-          // k^T -> k
-          else if (isSplat(TA)) {
-            ReplaceAllUsesWith(I, TA);
-            EraseFromParent(&I);
-          }
-          // (A * B)^t -> B^t * A^t
-          // RxK KxC      CxK   KxR
-          else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
-                                 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
-                                 m_ConstantInt(K), m_ConstantInt(C)))) {
-            NewInst = distributeTransposes(
-                TAMB, {K, C}, TAMA, {R, K}, Builder,
-                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
-                  return Builder.CreateMatrixMultiply(
-                      T0, T1, Shape0.NumRows, Shape0.NumColumns,
-                      Shape1.NumColumns, "mmul");
-                });
-            ReplaceAllUsesWith(I, NewInst);
-            EraseFromParent(&I);
-            EraseFromParent(TA);
-            // Same as above, but with a mul, which occurs when multiplied
-            // with a scalar.
-            // (A * k)^t -> A^t * k
-            //  R  x  C     RxC
-          } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
-                     (isSplat(TAMA) || isSplat(TAMB))) {
-            IRBuilder<> LocalBuilder(&I);
-            // We know that the transposed operand is of shape RxC.
-            // An when multiplied with a scalar, the shape is preserved.
-            NewInst = distributeTransposes(
-                TAMA, {R, C}, TAMB, {R, C}, Builder,
-                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
-                  bool IsFP = I.getType()->isFPOrFPVectorTy();
-                  auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
-                                   : LocalBuilder.CreateMul(T0, T1, "mmul");
-                  auto *Result = cast<Instruction>(Mul);
-                  setShapeInfo(Result, Shape0);
-                  return Result;
-                });
-            ReplaceAllUsesWith(I, NewInst);
-            EraseFromParent(&I);
-            EraseFromParent(TA);
-          }
-        }
-
-        // If we replaced I with a new instruction, continue from there.
-        if (NewInst)
+        if (Instruction *NewInst = sinkTranspose(I, II))
           II = std::next(BasicBlock::reverse_iterator(NewInst));
       }
     }
 
-    // If we have a TT matmul, lift the transpose.  We may be able to fold into
-    // consuming multiply.
+    // If we have a TT matmul or a TT add, lift the transpose. We may be able
+    // to fold it into a consuming multiply or add.
     for (BasicBlock &BB : Func) {
       for (Instruction &I : llvm::make_early_inc_range(BB)) {
-        Value *A, *B, *AT, *BT;
-        ConstantInt *R, *K, *C;
-        // A^t * B ^t -> (B * A)^t
-        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
-                          m_Value(A), m_Value(B), m_ConstantInt(R),
-                          m_ConstantInt(K), m_ConstantInt(C))) &&
-            match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
-            match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
-          IRBuilder<> IB(&I);
-          MatrixBuilder Builder(IB);
-          Value *M = Builder.CreateMatrixMultiply(
-              BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
-          setShapeInfo(M, {C, R});
-          Instruction *NewInst = Builder.CreateMatrixTranspose(
-              M, C->getZExtValue(), R->getZExtValue());
-          ReplaceAllUsesWith(I, NewInst);
-          if (I.use_empty())
-            I.eraseFromParent();
-          if (A->use_empty())
-            cast<Instruction>(A)->eraseFromParent();
-          if (A != B && B->use_empty())
-            cast<Instruction>(B)->eraseFromParent();
-        }
+        liftTranspose(I);
       }
     }
   }
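For the sink direction, a minimal IR sketch of the (A * B)^t -> B^t * A^t
rewrite performed by sinkTranspose (illustrative value names, 3x3 shapes as
in the tests):

```
; Before: the transpose of the product is realized after the multiply.
%ab = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %a, <9 x double> %b, i32 3, i32 3, i32 3)
%t = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %ab, i32 3, i32 3)

; After sinking: the transposes sit on the operands, where they can
; cancel with existing transposes or later fold into the multiply.
%at = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %b, i32 3, i32 3)
%mmul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %bt, <9 x double> %at, i32 3, i32 3, i32 3)
```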

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
index 9f1670360191..7510f43d3198 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -100,10 +100,9 @@ define void @at_plus_bt(ptr %Aptr, ptr %Bptr, ptr %C) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
-; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
-; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[AT]], [[BT]]
-; CHECK-NEXT:    store <9 x double> [[FADD]], ptr [[C:%.*]], align 128
+; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -122,9 +121,9 @@ define void @a_plus_b_t(ptr %Aptr, ptr %Bptr, ptr %C) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, ptr [[APTR:%.*]], align 128
 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[A]], [[B]]
-; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
-; CHECK-NEXT:    store <9 x double> [[T]], ptr [[C:%.*]], align 128
+; CHECK-NEXT:    [[MFADD1:%.*]] = fadd <9 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD1]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[C:%.*]], align 128
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -145,11 +144,10 @@ define void @atbt_plus_ctdt(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, ptr %E)
 ; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
 ; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
 ; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
-; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
-; CHECK-NEXT:    [[TMP2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
-; CHECK-NEXT:    [[TMP3:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP2]], i32 3, i32 3)
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    store <9 x double> [[FADD]], ptr [[E:%.*]], align 128
+; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[MFADD_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[MFADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MFADD_T]], ptr [[E:%.*]], align 128
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -200,17 +198,13 @@ define void @atbt_plus_kctdt_t(ptr %Aptr, ptr %Bptr, ptr %Cptr, ptr %Dptr, doubl
 ; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, ptr [[BPTR:%.*]], align 128
 ; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, ptr [[CPTR:%.*]], align 128
 ; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, ptr [[DPTR:%.*]], align 128
-; CHECK-NEXT:    [[CT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[C]], i32 3, i32 3)
-; CHECK-NEXT:    [[DT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[D]], i32 3, i32 3)
-; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
-; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
 ; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
 ; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
-; CHECK-NEXT:    [[KCT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[CT]], i32 3, i32 3, i32 3)
-; CHECK-NEXT:    [[KCTDT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[KCT]], <9 x double> [[DT]], i32 3, i32 3, i32 3)
-; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[KCTDT]]
-; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
-; CHECK-NEXT:    store <9 x double> [[T]], ptr [[E:%.*]], align 128
+; CHECK-NEXT:    [[MMUL2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[C]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[MMUL1]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MFADD:%.*]] = fadd <9 x double> [[MMUL2]], [[MMUL]]
+; CHECK-NEXT:    store <9 x double> [[MFADD]], ptr [[E:%.*]], align 128
 ; CHECK-NEXT:    ret void
 ;
 entry:

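The cancellations in the tests above bottom out in the simplest sink-side
rewrite, the double-transpose fold; a minimal sketch with illustrative names:

```
; transpose(transpose(%a)) is a nop: uses of %t2 are rewired to %a, and
; both calls are erased once they become dead.
%t1 = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %a, i32 3, i32 3)
%t2 = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> %t1, i32 3, i32 3)
```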
diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 33a338dbc4ea..82ae93b31035 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -4,35 +4,31 @@
 define <8 x double> @fadd_transpose(<8 x double> %a, <8 x double> %b) {
 ; CHECK-LABEL: @fadd_transpose(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
-; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[SPLIT]], [[SPLIT4]]
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[SPLIT1]], [[SPLIT5]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[SPLIT2]], [[SPLIT6]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[SPLIT3]], [[SPLIT7]]
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
-; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT3]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x double> [[SPLIT2]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x double> [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> poison, double [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP3]], double [[TMP4]], i64 1
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x double> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> poison, double [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <2 x double> [[TMP7]], double [[TMP8]], i64 1
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x double> [[TMP0]], i64 2
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x double> [[TMP1]], i64 2
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <2 x double> [[TMP11]], double [[TMP12]], i64 1
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x double> [[TMP0]], i64 3
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <2 x double> poison, double [[TMP14]], i64 0
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x double> [[TMP1]], i64 3
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <2 x double> [[TMP15]], double [[TMP16]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP18]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
 ; CHECK-NEXT:    ret <8 x double> [[TMP20]]
 ;
 entry:
@@ -46,37 +42,40 @@ define <8 x double> @load_fadd_transpose(ptr %A.Ptr, <8 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[A_PTR:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[A_PTR]], i64 2
-; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
-; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <2 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
-; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, ptr [[VEC_GEP4]], align 8
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[COL_LOAD]], [[SPLIT]]
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[COL_LOAD1]], [[SPLIT6]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[COL_LOAD3]], [[SPLIT7]]
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[COL_LOAD5]], [[SPLIT8]]
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
-; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    ret <8 x double> [[TMP20]]
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
+; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <2 x double>, ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD4]], <2 x double> [[COL_LOAD6]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[TMP0]], <4 x double> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT8]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fadd <4 x double> [[SPLIT7]], [[SPLIT9]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x double> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x double> poison, double [[TMP9]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x double> [[TMP4]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <2 x double> [[TMP10]], double [[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <2 x double> poison, double [[TMP13]], i64 0
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x double> [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <2 x double> [[TMP14]], double [[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i64 0
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <4 x double> [[TMP4]], i64 3
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <2 x double> [[TMP18]], double [[TMP19]], i64 1
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> [[TMP12]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x double> [[TMP16]], <2 x double> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP22]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP23]]
 ;
 