[llvm] bfd3883 - [Matrix] Refactor transpose distribution. NFC

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 28 17:30:16 PDT 2022


Author: Francis Visoiu Mistrih
Date: 2022-07-28T17:30:00-07:00
New Revision: bfd3883e83dd61c68bd920d933f8ca679b788ad9

URL: https://github.com/llvm/llvm-project/commit/bfd3883e83dd61c68bd920d933f8ca679b788ad9
DIFF: https://github.com/llvm/llvm-project/commit/bfd3883e83dd61c68bd920d933f8ca679b788ad9.diff

LOG: [Matrix] Refactor transpose distribution. NFC

Use a function to distribute transposes. Preparation for future patches.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f1e1359255bdd..c4281b6060fe1 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -384,6 +384,9 @@ class LowerMatrixIntrinsics {
         return NumColumns;
       return NumRows;
     }
+
+    /// Returns the transposed shape.
+    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
   };
 
   /// Maps instructions to their shape information. The shape information
@@ -684,6 +687,25 @@ class LowerMatrixIntrinsics {
     return NewWorkList;
   }
 
+  /// (Op0 op Op1)^T -> Op0^T op Op1^T
+  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
+  /// them on both sides of \p Operation.
+  Instruction *distributeTransposes(
+      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
+      MatrixBuilder &Builder,
+      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
+          Operation) {
+    Value *T0 = Builder.CreateMatrixTranspose(
+        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
+    // We are being run after shape prop, add shape for newly created
+    // instructions so that we lower them later.
+    setShapeInfo(T0, Shape0.t());
+    Value *T1 = Builder.CreateMatrixTranspose(
+        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
+    setShapeInfo(T1, Shape1.t());
+    return Operation(T0, Shape0.t(), T1, Shape1.t());
+  }
+
   /// Try moving transposes in order to fold them away or into multiplies.
   void optimizeTransposes() {
     auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
@@ -741,19 +763,13 @@ class LowerMatrixIntrinsics {
           else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                                  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                                  m_ConstantInt(K), m_ConstantInt(C)))) {
-            Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
-                                                      C->getZExtValue(),
-                                                      TAMB->getName() + "_t");
-            // We are being run after shape prop, add shape for newly created
-            // instructions so that we lower them later.
-            setShapeInfo(T0, {C, K});
-            Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
-                                                      K->getZExtValue(),
-                                                      TAMA->getName() + "_t");
-            setShapeInfo(T1, {K, R});
-            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
-                                                   K->getZExtValue(),
-                                                   R->getZExtValue(), "mmul");
+            NewInst = distributeTransposes(
+                TAMB, {K, C}, TAMA, {R, K}, Builder,
+                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+                  return Builder.CreateMatrixMultiply(
+                      T0, T1, Shape0.NumRows, Shape0.NumColumns,
+                      Shape1.NumColumns, "mmul");
+                });
             ReplaceAllUsesWith(I, NewInst);
             EraseFromParent(&I);
             EraseFromParent(TA);


        


More information about the llvm-commits mailing list