[llvm] 796fb2e - [Matrix] Move multiply-add code generation into separate function (NFC).
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 19 13:26:51 PDT 2020
Author: Florian Hahn
Date: 2020-03-19T20:26:19Z
New Revision: 796fb2e474989ce91af297b2e283115d9f4ca496
URL: https://github.com/llvm/llvm-project/commit/796fb2e474989ce91af297b2e283115d9f4ca496
DIFF: https://github.com/llvm/llvm-project/commit/796fb2e474989ce91af297b2e283115d9f4ca496.diff
LOG: [Matrix] Move multiply-add code generation into separate function (NFC).
This logic can be shared with the tiled code generation.
Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke
Reviewed By: anemet
Differential Revision: https://reviews.llvm.org/D75565
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 27ddb28aaa46..4044daeef77d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,6 +181,10 @@ class LowerMatrixIntrinsics {
void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+ /// Element type of the column vectors; reads Columns[0] without a size check.
+ Type *getElementType() {
+ return cast<VectorType>(Columns[0]->getType())->getElementType();
+ }
unsigned getNumColumns() const { return Columns.size(); }
unsigned getNumRows() const {
assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
@@ -848,6 +852,49 @@ class LowerMatrixIntrinsics {
}
}
+ /// Compute Result = A * B, or Result += A * B when isTiled, using chained
+ /// (left-associating) multiply-adds, vectorized in row blocks of <= VF.
+ void emitChainedMatrixMultiply(ColumnMatrixTy &Result,
+ const ColumnMatrixTy &A,
+ const ColumnMatrixTy &B, bool AllowContraction,
+ IRBuilder<> &Builder, bool isTiled) {
+ const unsigned VF = std::max<unsigned>(
+ TTI.getRegisterBitWidth(true) /
+ Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+ unsigned R = Result.getNumRows();
+ unsigned C = Result.getNumColumns();
+ unsigned M = A.getNumColumns();
+ // No shape checks here: assumes A has R rows and B has C columns of >= M elements.
+ for (unsigned J = 0; J < C; ++J) {
+ unsigned BlockSize = VF;
+ // BlockSize restarts at VF for every column and only shrinks within it.
+ // If column J of Result is all zeros, skip accumulating into it at K == 0.
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+ // Per-column op count, handed to createMulAdd (presumably by reference).
+ unsigned NumOps = 0;
+ for (unsigned I = 0; I < R; I += BlockSize) {
+ // Halve BlockSize until the row block [I, I + BlockSize) fits within R.
+ while (I + BlockSize > R)
+ BlockSize /= 2;
+ // Tiled: seed the sum with the current block of Result; otherwise start empty.
+ Value *Sum =
+ isTiled ? extractVector(Result, I, J, BlockSize, Builder) : nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *L = extractVector(A, I, K, BlockSize, Builder);
+ Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
+ Result.getElementType()->isFloatingPointTy(),
+ Builder, AllowContraction, NumOps);
+ }
+ Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
+ }
+ // NumOps was reset for this column, so it is added to Result once per column.
+ Result.addNumComputeOps(NumOps);
+ }
+ }
+
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
@@ -870,35 +917,9 @@ class LowerMatrixIntrinsics {
for (unsigned J = 0; J < C; ++J)
Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
- const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
- EltType->getPrimitiveSizeInBits(),
- uint64_t(1));
-
bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
MatMul->hasAllowContract());
- unsigned NumComputeOps = 0;
- // Multiply columns from the first operand with scalars from the second
- // operand. Then move along the K axes and accumulate the columns. With
- // this the adds can be vectorized without reassociation.
- for (unsigned J = 0; J < C; ++J) {
- unsigned BlockSize = VF;
- for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
- Value *Sum = nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
- Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
- Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
- Builder, AllowContract, NumComputeOps);
- }
- Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
- }
- }
- Result.addNumComputeOps(NumComputeOps);
+ emitChainedMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
finalizeLowering(MatMul, Result, Builder);
}
More information about the llvm-commits mailing list