[llvm] 796fb2e - [Matrix] Move multiply-add code generation into separate function (NFC).

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 19 13:26:51 PDT 2020


Author: Florian Hahn
Date: 2020-03-19T20:26:19Z
New Revision: 796fb2e474989ce91af297b2e283115d9f4ca496

URL: https://github.com/llvm/llvm-project/commit/796fb2e474989ce91af297b2e283115d9f4ca496
DIFF: https://github.com/llvm/llvm-project/commit/796fb2e474989ce91af297b2e283115d9f4ca496.diff

LOG: [Matrix] Move multiply-add code generation into separate function (NFC).

This logic can be shared with the tiled code generation.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D75565

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 27ddb28aaa46..4044daeef77d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,6 +181,14 @@ class LowerMatrixIntrinsics {
 
     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
 
+    /// Return the scalar element type of the matrix, taken from the vector
+    /// type of the first column. Requires at least one column.
+    Type *getElementType() const {
+      assert(Columns.size() > 0 &&
+             "Cannot call getElementType without columns");
+      return cast<VectorType>(Columns[0]->getType())->getElementType();
+    }
+
     unsigned getNumColumns() const { return Columns.size(); }
     unsigned getNumRows() const {
       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
@@ -848,6 +852,63 @@ class LowerMatrixIntrinsics {
     }
   }
 
+  /// Compute Result += A * B for tile-sized matrices with left-associating
+  /// addition, where Result is R x C, A is R x M and B is M x C.
+  ///
+  /// Columns of A are multiplied with splatted scalars from B, moving along
+  /// the K axis and accumulating the columns. With this, the adds can be
+  /// vectorized without reassociation.
+  ///
+  /// \p AllowContraction is forwarded to createMulAdd. If \p IsTiled is
+  /// true, the current values of Result are extracted and accumulated onto;
+  /// otherwise each sum starts from the first product and the previous
+  /// contents of Result are not read.
+  void emitChainedMatrixMultiply(ColumnMatrixTy &Result,
+                                 const ColumnMatrixTy &A,
+                                 const ColumnMatrixTy &B, bool AllowContraction,
+                                 IRBuilder<> &Builder, bool IsTiled) {
+    Type *EltType = Result.getElementType();
+    // Number of elements that fit in a vector register, but at least 1.
+    const unsigned VF = std::max<unsigned>(
+        TTI.getRegisterBitWidth(true) /
+            EltType->getPrimitiveSizeInBits().getFixedSize(),
+        1U);
+    const bool IsFP = EltType->isFloatingPointTy();
+    unsigned R = Result.getNumRows();
+    unsigned C = Result.getNumColumns();
+    unsigned M = A.getNumColumns();
+    assert(A.getNumRows() == R && B.getNumRows() == M &&
+           B.getNumColumns() == C && "Matrix shapes do not match");
+
+    for (unsigned J = 0; J < C; ++J) {
+      unsigned BlockSize = VF;
+
+      // If Result is zero, we don't need to accumulate in the K==0 iteration.
+      bool IsSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+
+      unsigned NumOps = 0;
+      for (unsigned I = 0; I < R; I += BlockSize) {
+        // Gradually lower the vectorization factor to cover the remainder.
+        while (I + BlockSize > R)
+          BlockSize /= 2;
+
+        // When tiled, accumulate onto the values already in Result.
+        Value *Sum =
+            IsTiled ? extractVector(Result, I, J, BlockSize, Builder) : nullptr;
+        for (unsigned K = 0; K < M; ++K) {
+          Value *L = extractVector(A, I, K, BlockSize, Builder);
+          Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+          Sum = createMulAdd(IsSumZero && K == 0 ? nullptr : Sum, L, Splat,
+                             IsFP, Builder, AllowContraction, NumOps);
+        }
+        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
+      }
+
+      Result.addNumComputeOps(NumOps);
+    }
+  }
+
   /// Lowers llvm.matrix.multiply.
   void LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
@@ -870,35 +917,9 @@ class LowerMatrixIntrinsics {
     for (unsigned J = 0; J < C; ++J)
       Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
 
-    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
-                                     EltType->getPrimitiveSizeInBits(),
-                                 uint64_t(1));
-
     bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                   MatMul->hasAllowContract());
-    unsigned NumComputeOps = 0;
-    // Multiply columns from the first operand with scalars from the second
-    // operand.  Then move along the K axes and accumulate the columns.  With
-    // this the adds can be vectorized without reassociation.
-    for (unsigned J = 0; J < C; ++J) {
-      unsigned BlockSize = VF;
-      for (unsigned I = 0; I < R; I += BlockSize) {
-        // Gradually lower the vectorization factor to cover the remainder.
-        while (I + BlockSize > R)
-          BlockSize /= 2;
-
-        Value *Sum = nullptr;
-        for (unsigned K = 0; K < M; ++K) {
-          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
-          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
-          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
-          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
-                             Builder, AllowContract, NumComputeOps);
-        }
-        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
-      }
-    }
-    Result.addNumComputeOps(NumComputeOps);
+    emitChainedMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
     finalizeLowering(MatMul, Result, Builder);
   }
 


        


More information about the llvm-commits mailing list