[Mlir-commits] [mlir] [MLIR][Linalg] Transform pass to optimize and lower vector contract operation to FMA (PR #121748)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 6 21:21:54 PST 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/121748
>From 84130f65a351a731f6ba153f5b147db59d7ec09c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 2 Jan 2025 07:27:29 -0800
Subject: [PATCH 1/3] =?UTF-8?q?initial=20changes=20for=20upstreaming=20hoi?=
=?UTF-8?q?st=20vector=20transfers=C2=A0and=20contract=20to=20fma?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../Linalg/TransformOps/LinalgTransformOps.td | 24 ++
.../Dialect/Linalg/Transforms/Transforms.h | 10 +
.../TransformOps/LinalgTransformOps.cpp | 13 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 2 +
.../Transforms/HoistVectorTransfers.cpp | 234 ++++++++++++
.../Linalg/Transforms/VectorContractToFMA.cpp | 356 ++++++++++++++++++
.../Dialect/Linalg/hoist-vector-transfer.mlir | 97 +++++
.../Linalg/vector-contract-to-fma.mlir | 113 ++++++
8 files changed, 849 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
create mode 100644 mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
create mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..d614bb4789767b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,6 +106,30 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+
+def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.hoist_vector_transfer",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects a pattern that hoists the accumulator `vector.transfer_read` /
+    `vector.transfer_write` pair of a `vector.contract` out of its enclosing
+    reduction and K `scf.for` loops. It is intended for batch-reduce GEMM
+    contractions (two reduction dimensions) with K == 1.
+
+    All three contraction operands must be produced by
+    `vector.transfer_read`s of `memref.subview`s whose offsets are the
+    induction variables of the four enclosing `scf.for` loops (M, N,
+    reduction, K), and the contraction result must be stored with a
+    `vector.transfer_write` inside the K loop. After the rewrite the
+    accumulator is read once before the reduction loop, carried through both
+    loops as an `iter_args` value, and written back once after them.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects a pattern that lowers a (batch-reduce) GEMM `vector.contract`
+    with K == 1 into a sequence of `vector.fma` ops.
+
+    The contraction's M x N accumulator must be threaded through the
+    `iter_args` of its two enclosing `scf.for` loops and be initialized by a
+    `vector.transfer_read`, as produced by
+    `apply_patterns.vector.hoist_vector_transfer`. The accumulator is split
+    into M row vectors of N elements that become the `iter_args` of the
+    re-created loops. In each iteration every LHS element is broadcast and
+    fused with the RHS row, and after the loops the rows are stored back with
+    `vector.store` in place of the original `vector.transfer_write`.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.pad_vectorization",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..d35f99826004ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,6 +1824,16 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
+/// Populates `patterns` with a pattern that hoists the accumulator
+/// vector.transfer_read/vector.transfer_write of a batch-reduce GEMM
+/// vector.contract out of its enclosing reduction and K scf.for loops; the
+/// accumulator is then carried through those loops as an iter_arg.
+void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
+
+
+/// Populates `patterns` with a pattern that lowers a K == 1 GEMM
+/// vector.contract with an M x N accumulator into a sequence of vector.fma
+/// ops inside scf.for loops whose iter_args carry the M accumulator rows.
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 221ca27b80fdd0..849632aed8a13b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,6 +262,19 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
linalg::populateFoldAddIntoDestPatterns(patterns);
}
+
+/// Forwards to linalg::populateHoistVectorTransferPatterns, which registers
+/// the accumulator-hoisting pattern into `patternSet`.
+void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
+    RewritePatternSet &patternSet) {
+  ::mlir::linalg::populateHoistVectorTransferPatterns(patternSet);
+}
+
+
+/// Forwards to linalg::populateVectorContractToFMAPatterns, which registers
+/// the contract-to-FMA lowering pattern into `patternSet`.
+void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patternSet) {
+  ::mlir::linalg::populateVectorContractToFMAPatterns(patternSet);
+}
+
+
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..90d926201cd753 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,6 +41,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
DecomposeGenericByUnfoldingPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp
+ HoistVectorTransfers.cpp
+ VectorContractToFMA.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
new file mode 100644
index 00000000000000..5911d36cbafef6
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
@@ -0,0 +1,234 @@
+//===- HoistVectorTransfers.cpp - Hoist vector transfers --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
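+// This file implements a rewrite pattern that hoists the accumulator
+// vector.transfer_read/vector.transfer_write pair of a batch-reduce GEMM
+// vector.contract out of its enclosing reduction and K scf.for loops.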
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Collects the vector.transfer_read ops that define the first three operands
+/// of `contractOp` (LHS, RHS and accumulator, in that order). Fails as soon as
+/// one of them is produced by anything other than a transfer read.
+static FailureOr<SmallVector<vector::TransferReadOp>>
+getContractOperands(vector::ContractionOp contractOp) {
+  SmallVector<vector::TransferReadOp> reads;
+  for (Value operand : contractOp->getOperands().take_front(3)) {
+    auto read = operand.getDefiningOp<vector::TransferReadOp>();
+    if (!read)
+      return failure();
+    reads.push_back(read);
+  }
+  return reads;
+}
+
+/// Returns, in the same order as `readOps`, the memref.subview op that each
+/// transfer read reads from. Fails if any read's source is not produced by a
+/// memref.subview.
+///
+/// Takes an ArrayRef rather than a SmallVector by value so callers do not pay
+/// for a copy of their vector.
+static FailureOr<SmallVector<memref::SubViewOp>>
+getReadOperands(ArrayRef<vector::TransferReadOp> readOps) {
+  SmallVector<memref::SubViewOp> subviews;
+  subviews.reserve(readOps.size());
+  // Copy each op handle: MLIR op classes are cheap value types and their
+  // accessors are not const-qualified.
+  for (vector::TransferReadOp readOp : readOps) {
+    auto subViewOp = readOp.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!subViewOp)
+      return failure();
+    subviews.push_back(subViewOp);
+  }
+  return subviews;
+}
+
+/// Walks outwards from `contractOp` and returns its four closest scf.for
+/// ancestors, innermost first (the caller treats them as the K, reduction,
+/// N and M loops). Fails if fewer than four scf.for ancestors exist.
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp) {
+  SmallVector<scf::ForOp> enclosingLoops;
+  Operation *cursor = contractOp.getOperation();
+  while (enclosingLoops.size() < 4) {
+    auto loop = cursor->getParentOfType<scf::ForOp>();
+    if (!loop)
+      return failure();
+    enclosingLoops.push_back(loop);
+    cursor = loop.getOperation();
+  }
+  return enclosingLoops;
+}
+
+/// Checks that the induction variables of `loops` (innermost first: K,
+/// reduction, N, M) are the offsets used by `subviews` (LHS, RHS,
+/// accumulator), i.e. the access pattern of a tiled batch-reduce GEMM:
+///   LHS[reduction, M, K], RHS[reduction, K, N], ACC[M, N].
+///
+/// SubViewOp::getOffsets() returns only the *dynamic* offset operands, so a
+/// subview with static offsets has fewer entries than its rank. Such
+/// subviews, and too-short inputs, are rejected instead of being indexed out
+/// of range.
+static LogicalResult checkNestedLoop(ArrayRef<scf::ForOp> loops,
+                                     ArrayRef<memref::SubViewOp> subviews) {
+  if (loops.size() < 4 || subviews.size() < 3)
+    return failure();
+
+  // Copy the handles: op accessors are not const-qualified.
+  scf::ForOp kLoop = loops[0];
+  scf::ForOp reductionLoop = loops[1];
+  scf::ForOp nLoop = loops[2];
+  scf::ForOp mLoop = loops[3];
+  memref::SubViewOp lhsView = subviews[0];
+  memref::SubViewOp rhsView = subviews[1];
+  memref::SubViewOp accView = subviews[2];
+
+  auto lhsOffsets = lhsView.getOffsets();
+  auto rhsOffsets = rhsView.getOffsets();
+  auto accOffsets = accView.getOffsets();
+  if (lhsOffsets.size() != 3 || rhsOffsets.size() != 3 ||
+      accOffsets.size() != 2)
+    return failure();
+
+  Value ivK = kLoop.getInductionVar();
+  if (ivK != lhsOffsets[2] || ivK != rhsOffsets[1])
+    return failure();
+
+  Value ivReduction = reductionLoop.getInductionVar();
+  if (ivReduction != lhsOffsets[0] || ivReduction != rhsOffsets[0])
+    return failure();
+
+  Value ivN = nLoop.getInductionVar();
+  if (ivN != accOffsets[1] || ivN != rhsOffsets[2])
+    return failure();
+
+  Value ivM = mLoop.getInductionVar();
+  if (ivM != lhsOffsets[1] || ivM != accOffsets[0])
+    return failure();
+
+  return success();
+}
+
+namespace {
+/// Hoists the accumulator vector.transfer_read / vector.transfer_write pair of
+/// a batch-reduce GEMM vector.contract out of its reduction and K loops.
+///
+/// Matched structure (four innermost scf.for loops, outermost first):
+///
+///   scf.for %m {
+///     scf.for %n {
+///       scf.for %r {                   // body: only the K loop
+///         scf.for %k {
+///           %a   = vector.transfer_read %lhsSubview[...]   // [r, m, k]
+///           %b   = vector.transfer_read %rhsSubview[...]   // [r, k, n]
+///           %c   = vector.transfer_read %accSubview[...]   // [m, n]
+///           %res = vector.contract %a, %b, %c
+///           vector.transfer_write %res, %accSubview[...]
+///   }}}}
+///
+/// After the rewrite %c is read once before the %r loop, carried through the
+/// %r and %k loops as an iter_arg, and written back once after the %r loop.
+/// All IR mutations go through the PatternRewriter.
+struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // All three contraction operands must come from transfer reads of
+    // memref.subview ops.
+    auto operands = getContractOperands(contractOp);
+    if (failed(operands))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid operands for contract op");
+    SmallVector<vector::TransferReadOp> readOps = *operands;
+    vector::TransferReadOp vectorReadOpLhs = readOps[0];
+    vector::TransferReadOp vectorReadOpRhs = readOps[1];
+    vector::TransferReadOp vectorReadOpAcc = readOps[2];
+
+    auto subviews = getReadOperands(readOps);
+    if (failed(subviews))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Vector read op operands are not a subview");
+
+    // Only batch-reduce matmul (exactly two reduction dimensions) is handled;
+    // in particular a contraction without reductions must not slip through.
+    SmallVector<vector::IteratorType> contractIteratorTypes =
+        contractOp.getIteratorTypesArray();
+    int64_t reductionCount =
+        llvm::count(contractIteratorTypes, vector::IteratorType::reduction);
+    if (reductionCount == 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Batch matmul operation not supported yet");
+    if (reductionCount != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The vector contract operation is not a gemm");
+
+    auto vectorReadOpLhsType = cast<ShapedType>(vectorReadOpLhs.getType());
+    auto vectorReadOpRhsType = cast<ShapedType>(vectorReadOpRhs.getType());
+    if (vectorReadOpLhsType.getRank() != 3 ||
+        vectorReadOpRhsType.getRank() != 3)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid rank for batch reduce operation");
+
+    // The K tile (innermost LHS dimension) must be 1.
+    int64_t K =
+        vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
+    if (K != 1)
+      return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
+
+    // The four enclosing loops must index the subviews like a tiled GEMM.
+    auto loops = getNestedLoop(contractOp);
+    if (failed(loops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid loop nest in contract pattern");
+    if (failed(checkNestedLoop(*loops, *subviews)))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Loops doesn't match the iv in subviews");
+
+    scf::ForOp kForOp = (*loops)[0];
+    scf::ForOp reductionForOp = (*loops)[1];
+
+    // Only the reduction and K loops are rebuilt, so anything else living in
+    // them would be lost: require direct nesting, a reduction-loop body that
+    // is just the K loop (plus terminator), and no existing iter_args.
+    if (contractOp->getParentOp() != kForOp.getOperation() ||
+        kForOp->getParentOp() != reductionForOp.getOperation() ||
+        reductionForOp.getBody()->getOperations().size() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Reduction and K loops are not a perfect nest");
+    if (kForOp.getNumResults() != 0 || reductionForOp.getNumResults() != 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Reduction and K loops must not have iter_args");
+
+    // The accumulator must be read only for the contraction, and the
+    // contraction result must only be written back, inside the K loop, to the
+    // location it was read from. Checked before any IR is touched.
+    if (vectorReadOpAcc->getParentOp() != kForOp.getOperation() ||
+        !vectorReadOpAcc->hasOneUse() || !contractOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Accumulator read/contract result has unexpected uses");
+    auto vectorWriteOp =
+        dyn_cast<vector::TransferWriteOp>(*contractOp->user_begin());
+    Value contractResult = contractOp.getResult();
+    if (!vectorWriteOp ||
+        vectorWriteOp->getParentOp() != kForOp.getOperation() ||
+        vectorWriteOp.getVector() != contractResult ||
+        vectorWriteOp.getSource() != vectorReadOpAcc.getSource() ||
+        !llvm::equal(vectorWriteOp.getIndices(), vectorReadOpAcc.getIndices()))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Contract result is not written back to the accumulator");
+
+    // Hoisting is only legal if every operand of the read and of the write
+    // (other than the written vector) is available outside the reduction
+    // loop.
+    auto isLoopInvariant = [&](Value v) {
+      return reductionForOp.isDefinedOutsideOfLoop(v);
+    };
+    if (!llvm::all_of(vectorReadOpAcc->getOperands(), isLoopInvariant) ||
+        !llvm::all_of(vectorWriteOp->getOperands(), [&](Value v) {
+          return v == contractResult || isLoopInvariant(v);
+        }))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Accumulator read/write is not loop invariant");
+
+    // 1. Read the accumulator once, right before the reduction loop.
+    rewriter.setInsertionPoint(reductionForOp);
+    Operation *hoistedRead = rewriter.clone(*vectorReadOpAcc.getOperation());
+    Value accInKLoop = vectorReadOpAcc->getResult(0);
+
+    // 2. Rebuild the reduction and K loops carrying the accumulator. The
+    //    in-loop read and write are not cloned: uses of the read now see the
+    //    loop-carried value, and the contraction result is yielded instead.
+    auto newReductionForOp = rewriter.create<scf::ForOp>(
+        reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+        reductionForOp.getUpperBound(), reductionForOp.getStep(),
+        ValueRange{hoistedRead->getResult(0)},
+        [&](OpBuilder &reductionBuilder, Location reductionLoc,
+            Value reductionIv, ValueRange reductionIterArgs) {
+          IRMapping mapper;
+          mapper.map(reductionForOp.getInductionVar(), reductionIv);
+          // Remap in case the K-loop bounds use the reduction iv.
+          Value kLb = kForOp.getLowerBound();
+          Value kUb = kForOp.getUpperBound();
+          Value kStep = kForOp.getStep();
+          auto newKForOp = reductionBuilder.create<scf::ForOp>(
+              kForOp.getLoc(), mapper.lookupOrDefault(kLb),
+              mapper.lookupOrDefault(kUb), mapper.lookupOrDefault(kStep),
+              reductionIterArgs,
+              [&](OpBuilder &kBuilder, Location kLoc, Value kIv,
+                  ValueRange kIterArgs) {
+                mapper.map(kForOp.getInductionVar(), kIv);
+                mapper.map(accInKLoop, kIterArgs.front());
+                for (Operation &op : kForOp.getBody()->without_terminator()) {
+                  if (&op == vectorReadOpAcc.getOperation() ||
+                      &op == vectorWriteOp.getOperation())
+                    continue;
+                  kBuilder.clone(op, mapper);
+                }
+                kBuilder.create<scf::YieldOp>(kLoc,
+                                              mapper.lookup(contractResult));
+              });
+          reductionBuilder.create<scf::YieldOp>(reductionLoc,
+                                                newKForOp.getResult(0));
+        });
+
+    // 3. Write the final accumulator back once, after the new loops.
+    rewriter.setInsertionPointAfter(newReductionForOp);
+    IRMapping writeMapper;
+    writeMapper.map(contractResult, newReductionForOp.getResult(0));
+    rewriter.clone(*vectorWriteOp.getOperation(), writeMapper);
+
+    // 4. The original loop nest (which contains contractOp) is now dead.
+    //    Erase it through the rewriter so the driver is notified, instead of
+    //    erasing users while iterating over the use-list.
+    rewriter.eraseOp(reductionForOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers HoistVectorTransferOp in `patterns`, using the set's context.
+void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<HoistVectorTransferOp>(context);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 00000000000000..2a8132b93bdcb0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,356 @@
+
+//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector contraction to vector fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "vector-contract-to-fma"
+
+using namespace mlir;
+
+/// Returns true if `map`, a 2-result operand map of a (batch-reduce) matmul,
+/// is transposed. With m, n and k being the last three dimensions of the
+/// map's domain, the transposed forms are (k, m) for the LHS and (n, k) for
+/// the RHS; the plain forms (m, k) and (k, n) yield false. The output map
+/// (m, n) must not be passed in (asserted).
+static bool isTransposed(AffineMap map) {
+  unsigned numInputs = map.getNumInputs();
+  MLIRContext *ctx = map.getContext();
+  // The domain is ([b,] m, n, k) and the map has exactly two results.
+  assert((numInputs == 3 || numInputs == 4) && "3 or 4 input dim expected");
+  assert(map.getNumResults() == 2 && "Only 2 output dim expected");
+
+  auto first = dyn_cast<AffineDimExpr>(map.getResult(0));
+  auto second = dyn_cast<AffineDimExpr>(map.getResult(1));
+  assert((first && second) && "Unexpected dim expression");
+
+  AffineExpr m = getAffineDimExpr(numInputs - 3, ctx);
+  AffineExpr n = getAffineDimExpr(numInputs - 2, ctx);
+  AffineExpr k = getAffineDimExpr(numInputs - 1, ctx);
+  assert(!(first == m && second == n) && "Output result map not expected");
+
+  bool isKM = first == k && second == m;
+  bool isNK = first == n && second == k;
+  return isKM || isNK;
+}
+
+
+namespace {
+/// The three scf.for loops enclosing the contraction, innermost first, found
+/// by repeatedly taking the closest scf.for ancestor. Kept in an anonymous
+/// namespace: it is file-local and its name is generic enough to collide with
+/// another translation unit (ODR).
+struct TransformationContext {
+  /// Closest scf.for ancestor of the vector.contract.
+  scf::ForOp innerForOp;
+  /// Closest scf.for ancestor of innerForOp.
+  scf::ForOp outerForOp;
+  /// Closest scf.for ancestor of outerForOp.
+  scf::ForOp outermostLoop;
+};
+
+/// Matmul variants distinguished by the contraction's iterator types.
+enum class MatMulType {
+  Standard,   ///< (m, n, k)
+  Batch,      ///< (b, m, n, k) with a parallel leading dimension
+  BatchReduce ///< (b, m, n, k) with a reduction leading dimension
+};
+} // namespace
+
+namespace {
+/// Lowers a GEMM / batch-reduce GEMM vector.contract with K == 1 into
+/// vector.fma ops. The accumulator must be threaded through the iter_args of
+/// the two enclosing scf.for loops (the form produced by the
+/// hoist-vector-transfer pattern).
+///
+///  * The M x N accumulator memref (source of the transfer_read initializing
+///    the outer loop's iter_arg) is split into M subviews of shape 1 x N,
+///    each loaded as an N-element vector.
+///  * Both loops are re-created carrying these M row vectors.
+///  * Per iteration, the i-th LHS element is broadcast to N lanes and fused
+///    with the RHS row: acc[i] = fma(bcast(lhs[i]), rhsRow, acc[i]).
+///  * The M rows are stored back with vector.store, replacing the
+///    transfer_write of the outer loop's result, and the old loops are
+///    erased.
+///
+/// Every precondition is checked before the IR is modified, so a failed match
+/// leaves the IR untouched and a successful one removes the matched op.
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported combining kind, only supports ADD at the moment)");
+
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
+
+    SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
+    if (llvm::any_of(
+            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
+      return rewriter.notifyMatchFailure(op, "Unexpected map");
+
+    // Classify the matmul variant. Only (m, n, k) and (b, m, n, k) iteration
+    // spaces are understood; isTransposed() asserts on larger ones.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    MatMulType matmulType;
+    unsigned outerDimIndex = 0;
+    if (iteratorTypes.size() == 4) {
+      matmulType = iteratorTypes[0] == vector::IteratorType::parallel
+                       ? MatMulType::Batch
+                       : MatMulType::BatchReduce;
+      outerDimIndex = 1;
+    } else if (iteratorTypes.size() == 3) {
+      matmulType = MatMulType::Standard;
+    } else {
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+    }
+
+    if (matmulType == MatMulType::Batch)
+      return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
+    if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+
+    auto lhsDefiningOp = op.getLhs().getDefiningOp<vector::TransferReadOp>();
+    auto rhsDefiningOp = op.getRhs().getDefiningOp<vector::TransferReadOp>();
+    if (!lhsDefiningOp || !rhsDefiningOp)
+      return failure();
+
+    // The accumulator must not be read directly; it has to arrive through
+    // the iter_args chain verified below.
+    Value acc = op.getAcc();
+    if (acc.getDefiningOp<vector::TransferReadOp>())
+      return failure();
+
+    // LHS/RHS are re-read element-wise from their memrefs, so the transfer
+    // reads must start at offset zero, be unpermuted and in-bounds.
+    if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
+        !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex))
+      return failure();
+    if (!lhsDefiningOp.getPermutationMap().isIdentity() ||
+        !rhsDefiningOp.getPermutationMap().isIdentity() ||
+        lhsDefiningOp.hasOutOfBoundsDim() || rhsDefiningOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(
+          op, "LHS/RHS reads must be in-bounds identity reads");
+
+    // The subviews they read from are cloned into the new loops.
+    auto lhsSubview =
+        lhsDefiningOp.getSource().getDefiningOp<memref::SubViewOp>();
+    auto rhsSubview =
+        rhsDefiningOp.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!lhsSubview || !rhsSubview)
+      return rewriter.notifyMatchFailure(
+          op, "LHS/RHS are not read from a memref.subview");
+
+    auto lhsType = cast<VectorType>(lhsDefiningOp.getType());
+    auto rhsType = cast<VectorType>(rhsDefiningOp.getType());
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return rewriter.notifyMatchFailure(op, "Scalable vectors not supported");
+
+    int64_t expectedRank = matmulType == MatMulType::BatchReduce ? 3 : 2;
+    if (lhsType.getRank() != expectedRank || rhsType.getRank() != expectedRank)
+      return failure();
+
+    // Operand layouts: the accumulator must be exactly (m, n) and, for the
+    // batch-reduce form, both inputs must lead with the reduced dimension d0.
+    MLIRContext *mlirCtx = rewriter.getContext();
+    unsigned numDims = maps[2].getNumDims();
+    if (maps[2].getNumResults() != 2 ||
+        maps[2].getResult(0) != getAffineDimExpr(numDims - 3, mlirCtx) ||
+        maps[2].getResult(1) != getAffineDimExpr(numDims - 2, mlirCtx))
+      return rewriter.notifyMatchFailure(op, "Unexpected accumulator map");
+    AffineMap mapLHS = maps[0];
+    AffineMap mapRHS = maps[1];
+    if (matmulType == MatMulType::BatchReduce) {
+      AffineExpr batchDim = getAffineDimExpr(0, mlirCtx);
+      if (mapLHS.getResult(0) != batchDim || mapRHS.getResult(0) != batchDim)
+        return rewriter.notifyMatchFailure(op, "Unexpected batch dimension");
+      mapLHS = mapLHS.dropResult(0);
+      mapRHS = mapRHS.dropResult(0);
+    }
+    if (isTransposed(mapLHS) || isTransposed(mapRHS))
+      return rewriter.notifyMatchFailure(
+          op, "Transposed matrices are not expected");
+
+    // Locate the enclosing loops. The outermost one is not rewritten but is
+    // kept as a structural precondition of this lowering.
+    TransformationContext ctx;
+    ctx.innerForOp = op->getParentOfType<scf::ForOp>();
+    if (!ctx.innerForOp)
+      return failure();
+    ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outerForOp)
+      return failure();
+    ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outermostLoop)
+      return failure();
+
+    // Verify the chain acc -> inner iter_arg -> outer iter_arg ->
+    // vector.transfer_read, with exactly one loop-carried value per loop.
+    if (ctx.innerForOp.getRegionIterArgs().size() != 1 ||
+        ctx.outerForOp.getNumResults() != 1)
+      return failure();
+    auto bbArg = dyn_cast<BlockArgument>(acc);
+    if (!bbArg || bbArg.getOwner() != ctx.innerForOp.getBody() ||
+        bbArg.getArgNumber() == 0)
+      return failure();
+    auto outerBBArg = dyn_cast<BlockArgument>(
+        ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1]);
+    if (!outerBBArg || outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
+        outerBBArg.getArgNumber() == 0)
+      return failure();
+    auto accDefiningOp =
+        ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1]
+            .getDefiningOp<vector::TransferReadOp>();
+    if (!accDefiningOp)
+      return failure();
+
+    // The loops must do nothing but accumulate the contraction: the inner
+    // loop yields the contraction result, the outer loop yields the inner
+    // loop's result, and no other op would be lost by rebuilding them.
+    auto innerYield =
+        cast<scf::YieldOp>(ctx.innerForOp.getBody()->getTerminator());
+    auto outerYield =
+        cast<scf::YieldOp>(ctx.outerForOp.getBody()->getTerminator());
+    if (innerYield->getOperand(0) != op.getResult() ||
+        outerYield->getOperand(0) != ctx.innerForOp.getResult(0) ||
+        ctx.outerForOp.getBody()->getOperations().size() != 2)
+      return rewriter.notifyMatchFailure(op, "Unexpected loop structure");
+    Operation *expectedInnerOps[] = {
+        op.getOperation(), lhsDefiningOp.getOperation(),
+        rhsDefiningOp.getOperation(), lhsSubview.getOperation(),
+        rhsSubview.getOperation()};
+    for (Operation &bodyOp : ctx.innerForOp.getBody()->without_terminator())
+      if (!llvm::is_contained(expectedInnerOps, &bodyOp))
+        return rewriter.notifyMatchFailure(op, "Unexpected op in inner loop");
+
+    // Cloned subviews may only depend on the two induction variables (which
+    // are remapped) or on values that dominate the old loop nest.
+    Value outerIv = ctx.outerForOp.getInductionVar();
+    Value innerIv = ctx.innerForOp.getInductionVar();
+    auto isAvailable = [&](Value v) {
+      return v == outerIv || v == innerIv ||
+             ctx.outerForOp.isDefinedOutsideOfLoop(v);
+    };
+    if (!llvm::all_of(lhsSubview->getOperands(), isAvailable) ||
+        !llvm::all_of(rhsSubview->getOperands(), isAvailable))
+      return rewriter.notifyMatchFailure(
+          op, "LHS/RHS subviews depend on values local to the loops");
+
+    // Only a 2-D, in-bounds, unpermuted read of a memref at offset zero is
+    // supported for the accumulator (tensors would crash the memref cast).
+    auto accType = cast<VectorType>(accDefiningOp.getType());
+    if (accType.getRank() != 2 || accType.isScalable())
+      return failure();
+    if (!llvm::all_of(accDefiningOp.getIndices(), isZeroIndex) ||
+        !accDefiningOp.getPermutationMap().isIdentity() ||
+        accDefiningOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(op, "Unsupported accumulator read");
+    Value accSubview = accDefiningOp.getSource();
+    auto memrefType = dyn_cast<MemRefType>(accSubview.getType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(op, "Accumulator is not a memref");
+
+    // vector.fma needs a single floating-point element type for all operands.
+    Type elementType = memrefType.getElementType();
+    if (!isa<FloatType>(elementType) ||
+        lhsType.getElementType() != elementType ||
+        rhsType.getElementType() != elementType ||
+        accType.getElementType() != elementType)
+      return rewriter.notifyMatchFailure(
+          op, "Mixed or non-float element types not supported");
+
+    int64_t M = accType.getDimSize(0);
+    int64_t N = accType.getDimSize(1);
+    int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
+    if (K != 1)
+      return failure();
+
+    // The outer loop's single result must only be written back to the
+    // accumulator at offset zero; that write is replaced by vector.store ops.
+    // Without it the rewrite would not make progress and would re-match.
+    Value matResult = ctx.outerForOp.getResult(0);
+    if (!matResult.hasOneUse())
+      return failure();
+    auto writeOp = dyn_cast<vector::TransferWriteOp>(*matResult.user_begin());
+    if (!writeOp || writeOp.getSource() != accSubview ||
+        !llvm::all_of(writeOp.getIndices(), isZeroIndex) ||
+        !writeOp.getPermutationMap().isIdentity() ||
+        writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(
+          op, "Result is not written back to the accumulator");
+
+    // ---- Rewrite --------------------------------------------------------
+    // Emit the new code right before the original write: every value it
+    // uses dominates that point.
+    Location loc = op.getLoc();
+    rewriter.setInsertionPoint(writeOp);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    VectorType rowVecType = VectorType::get({N}, elementType);
+
+    // Split the M x N accumulator into M row subviews of shape 1 x N...
+    SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1),
+                                          rewriter.getIndexAttr(N)};
+    SmallVector<OpFoldResult> unitStrides = {rewriter.getIndexAttr(1),
+                                             rewriter.getIndexAttr(1)};
+    SmallVector<Value, 4> rowSubviews;
+    for (int64_t i = 0; i < M; ++i) {
+      SmallVector<OpFoldResult> rowOffsets = {rewriter.getIndexAttr(i),
+                                              rewriter.getIndexAttr(0)};
+      rowSubviews.push_back(rewriter.create<memref::SubViewOp>(
+          loc, accSubview, rowOffsets, rowSizes, unitStrides));
+    }
+    // ...and load each row as an N-element initial accumulator.
+    SmallVector<Value, 4> initAccs;
+    for (Value row : rowSubviews)
+      initAccs.push_back(rewriter.create<vector::LoadOp>(
+          loc, rowVecType, row, ValueRange{c0, c0}));
+
+    Value innerLb = ctx.innerForOp.getLowerBound();
+    Value innerUb = ctx.innerForOp.getUpperBound();
+    Value innerStep = ctx.innerForOp.getStep();
+    int64_t lhsRank = lhsType.getRank();
+    int64_t rhsRank = rhsType.getRank();
+
+    // Re-create both loops with the M row accumulators as iter_args.
+    // Induction variables are remapped by identity rather than by subview
+    // operand position.
+    auto newOuterForOp = rewriter.create<scf::ForOp>(
+        loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
+        ctx.outerForOp.getStep(), initAccs,
+        [&](OpBuilder &outerBuilder, Location outerLoc, Value newOuterIv,
+            ValueRange outerIterArgs) {
+          IRMapping mapping;
+          mapping.map(outerIv, newOuterIv);
+          auto newInnerForOp = outerBuilder.create<scf::ForOp>(
+              outerLoc, mapping.lookupOrDefault(innerLb),
+              mapping.lookupOrDefault(innerUb),
+              mapping.lookupOrDefault(innerStep), outerIterArgs,
+              [&](OpBuilder &innerBuilder, Location innerLoc,
+                  Value newInnerIv, ValueRange innerIterArgs) {
+                mapping.map(innerIv, newInnerIv);
+
+                // Load the M LHS elements (M runs along dimension rank-2 of
+                // the LHS; all other indices are 0 since the reads start at
+                // offset zero and K == 1) and broadcast each to N lanes.
+                Operation *lhsClone =
+                    innerBuilder.clone(*lhsSubview.getOperation(), mapping);
+                SmallVector<Value, 4> broadcasts;
+                for (int64_t i = 0; i < M; ++i) {
+                  SmallVector<Value> lhsIndices(lhsRank, c0);
+                  lhsIndices[lhsRank - 2] =
+                      innerBuilder.create<arith::ConstantIndexOp>(innerLoc, i);
+                  auto elem = innerBuilder.create<memref::LoadOp>(
+                      innerLoc, lhsClone->getResult(0), lhsIndices);
+                  broadcasts.push_back(innerBuilder.create<vector::BroadcastOp>(
+                      innerLoc, rowVecType, elem));
+                }
+
+                // Load the single RHS row of N elements.
+                Operation *rhsClone =
+                    innerBuilder.clone(*rhsSubview.getOperation(), mapping);
+                SmallVector<Value> rhsIndices(rhsRank, c0);
+                Value rowVec = innerBuilder.create<vector::LoadOp>(
+                    innerLoc, rowVecType, rhsClone->getResult(0), rhsIndices);
+
+                // acc[i] = fma(broadcast(lhs[i]), rhsRow, acc[i]).
+                SmallVector<Value, 4> results;
+                for (int64_t i = 0; i < M; ++i)
+                  results.push_back(innerBuilder.create<vector::FMAOp>(
+                      innerLoc, broadcasts[i], rowVec, innerIterArgs[i]));
+                innerBuilder.create<scf::YieldOp>(innerLoc, results);
+              });
+          outerBuilder.create<scf::YieldOp>(outerLoc,
+                                            newInnerForOp.getResults());
+        });
+
+    // Store the M rows back over the original accumulator...
+    for (int64_t i = 0; i < M; ++i)
+      rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
+                                       rowSubviews[i], ValueRange{c0, c0});
+
+    // ...then remove the replaced write, the old loop nest (which contains
+    // `op`) and, if now unused, the original accumulator read.
+    rewriter.eraseOp(writeOp);
+    rewriter.eraseOp(ctx.outerForOp);
+    if (accDefiningOp->use_empty())
+      rewriter.eraseOp(accDefiningOp);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers VectorContractToFMA in `patterns`, using the set's context.
+void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<VectorContractToFMA>(context);
+}
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
new file mode 100644
index 00000000000000..f1f24e4e53cb66
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+// Input: a batch-reduce GEMM tiled into an M loop (%arg3), an N loop (%arg4),
+// a reduction loop (%arg5) and a K loop (%arg6). The accumulator tile
+// %subview_2 is read (%3) and written back inside the K loop on every
+// iteration; the hoisting pattern should move that read/write pair out of
+// the reduction loop.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+ func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c64 = arith.constant 64 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ scf.for %arg3 = %c0 to %c32 step %c4 {
+ scf.for %arg4 = %c0 to %c64 step %c64 {
+ %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ scf.for %arg5 = %c0 to %c24 step %c1 {
+ scf.for %arg6 = %c0 to %c64 step %c1 {
+ %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+ %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+ %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+ // Accumulator tile: read and written back on every K iteration.
+ %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+ vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+ }
+ }
+ }
+ }
+ }
+ return %alloc : memref<8x24x32x64xf32>
+ }
+
+
+// Expected: the accumulator read is hoisted above the reduction loop,
+// threaded through both loops as iter_args, and written back once after the
+// loop nest. Captured names are used throughout instead of printer-assigned
+// SSA/alias names.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL: func.func @simple_gemm(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 24) {
+// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_27]], %[[VAL_28]], %[[VAL_24]] : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+// CHECK: scf.yield %[[VAL_29]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_22]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+
+
+// Transform script: applies only the vector-transfer hoisting patterns to
+// every func.func matched in the payload.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
new file mode 100644
index 00000000000000..ba11074e3c9637
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+ func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c64 = arith.constant 64 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ scf.for %arg3 = %c0 to %c32 step %c4 {
+ scf.for %arg4 = %c0 to %c64 step %c64 {
+ %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
+ %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
+ %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+ %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+ %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+ %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+ scf.yield %6 : vector<4x64xf32>
+ }
+ scf.yield %3 : vector<4x64xf32>
+ }
+ vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+ }
+ }
+ }
+ return %alloc : memref<8x24x32x64xf32>
+ }
+
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL: func.func @simple_gemm(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
+// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
+// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
+// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
+// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
+// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK: }
+// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.vector.contract_to_fma // lowers the batch-reduce vector.contract inside the loop nest to vector.fma ops
+ } : !transform.any_op
+ transform.yield
+ }
+ }
>From 83ed5c4db59730cc0b3bd38a0e5c2551904992b1 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 02:21:34 -0800
Subject: [PATCH 2/3] Add more test cases and refactor code
---
.../Linalg/TransformOps/LinalgTransformOps.td | 8 +-
.../Dialect/Linalg/Transforms/Transforms.h | 9 +-
.../Transforms/HoistVectorTransfers.cpp | 49 ++++++--
.../Linalg/Transforms/VectorContractToFMA.cpp | 56 ++++++++-
.../Dialect/Linalg/hoist-vector-transfer.mlir | 93 ++++++++++++++-
.../Linalg/vector-contract-to-fma.mlir | 107 ++++--------------
6 files changed, 216 insertions(+), 106 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d614bb4789767b..9acaf1ba231a0d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -111,7 +111,8 @@ def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.hoist_vector_transfer",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Hoists the vector transfer reads/writes outside the reduction and k-loop.
+ Collects patterns that hoist, where possible, the vector transfer reads/writes
+ out of the reduction and k-loops of a tiled batch-reduce matmul.
}];
let assemblyFormat = "attr-dict";
@@ -122,9 +123,8 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.contract_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Implements the lowering of vector contraction op for GEMM of size MxN to
- sequence of vector FMAs wrapped inside scf.for loop with iterargs to
- accumulate the result of FMAs.
+ Collects patterns that lower a vector contraction (batch-reduce matmul) for a
+ GEMM of size MxN to a sequence of vector FMAs.
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d35f99826004ed..6f639b45408d87 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1825,13 +1825,12 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
-/// Pattern to hoists the vector transfer reads/writes outside the reduction and
-/// k-loop.
+/// Patterns to hoist vector transfer reads/writes outside the reduction and
+/// k-loops of a batch-reduce matmul when LICM cannot hoist them.
void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
-
-/// Pattern to lower vector contraction op for GEMM of size MxN to
-/// sequence of vector FMAs
+/// Patterns to lower a vector contraction (batch-reduce matmul) for a GEMM of
+/// size MxN to a sequence of vector FMAs.
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
index 5911d36cbafef6..1e741010c741ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
@@ -1,15 +1,11 @@
-//===-HoistVectorTransfers.cpp -----------------------------------------*-
-// C++-*-===//
+//===- HoistVectorTransfers.cpp ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-//
-// This file implements tile configuration hoisting on parallel loops.
-//
-//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -25,6 +21,7 @@
using namespace mlir;
+/// Retrieves the vector transfer read ops (Acc, Lhs, and Rhs) of a contraction.
static FailureOr<SmallVector<vector::TransferReadOp>>
getContractOperands(vector::ContractionOp contractOp) {
SmallVector<vector::TransferReadOp> list;
@@ -38,6 +35,7 @@ getContractOperands(vector::ContractionOp contractOp) {
return list;
}
+/// Retrieves the memref.subview that each vector transfer read reads from.
static FailureOr<SmallVector<memref::SubViewOp>>
getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
SmallVector<memref::SubViewOp> list;
@@ -50,6 +48,7 @@ getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
return list;
}
+/// Retrieves the tiled loop nest (m->n->reduction->k) around the contraction.
static FailureOr<SmallVector<scf::ForOp>>
getNestedLoop(vector::ContractionOp contractOp) {
SmallVector<scf::ForOp> list;
@@ -64,6 +63,7 @@ getNestedLoop(vector::ContractionOp contractOp) {
return list;
}
+/// Checks that the loop-nest induction variables match the subview offsets.
static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
SmallVector<memref::SubViewOp> subviews) {
auto subviewOpLhsOffsets = subviews[0].getOffsets();
@@ -90,6 +90,40 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
return success();
}
+/// Hoists the accumulator transfer_read/transfer_write of a tiled batch-reduce
+/// matmul out of its reduction and k-loops.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// scf.for %arg3 = %c0 to %c32 step %c4 // m-loop
+/// scf.for %arg4 = %c0 to %c64 step %c64 // n-loop
+/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// scf.for %arg5 = %c0 to %c24 step %c1 // reduction-loop
+/// scf.for %arg6 = %c0 to %c64 step %c1 // k-loop
+/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1]
+/// %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1]
+/// %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
+/// %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3
+/// vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// scf.for %arg3 = %c0 to %c32 step %c4
+/// scf.for %arg4 = %c0 to %c64 step %c64
+/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
+/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
+/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+/// scf.yield %6 : !type
+/// }
+/// scf.yield %3 : !type
+/// }
+/// vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]}
+///
struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -98,7 +132,6 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
// Check the vector contract operation satisfies the required pattern.
// Check the Acc, Lhs, and Rhs of contract operation
-
auto operands = getContractOperands(contractOp);
if (failed(operands))
return rewriter.notifyMatchFailure(contractOp,
@@ -145,7 +178,7 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
if (K != 1)
return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
- // Check whether the linalg tiling + vector contract pattern matches for the
+ // Check whether the BR-matmul tiling + vector contract pattern matches for the
// 4-nested loop structure
auto loops = getNestedLoop(contractOp);
if (failed(loops))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
index 2a8132b93bdcb0..4d3dac6a2b4d03 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -1,4 +1,3 @@
-
//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -67,6 +66,61 @@ struct TransformationContext {
enum class MatMulType { Standard, Batch, BatchReduce };
+
+/// Lowers a vector contraction (batch-reduce matmul) for a GEMM of size MxN
+/// to a sequence of vector FMAs.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]}
+/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+/// %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
+/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+/// scf.yield %6 : !type
+/// }
+/// scf.yield %3 : !type
+/// }
+/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1]
+/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1]
+/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1]
+/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1]
+/// %1 = vector.load %subview_2[%c0, %c0]
+/// %2 = vector.load %subview_3[%c0, %c0]
+/// %3 = vector.load %subview_4[%c0, %c0]
+/// %4 = vector.load %subview_5[%c0, %c0]
+/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
+/// %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
+/// %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1]
+/// %7 = memref.load %subview_6[%c0, %c0, %c0]
+/// %8 = vector.broadcast %7 : f32 to !type
+/// %9 = memref.load %subview_6[%c0, %c1, %c0]
+/// %10 = vector.broadcast %9 : f32 to !type
+/// %11 = memref.load %subview_6[%c0, %c2, %c0]
+/// %12 = vector.broadcast %11 : f32 to !type
+/// %13 = memref.load %subview_6[%c0, %c3, %c0]
+/// %14 = vector.broadcast %13 : f32 to !type
+/// %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1]
+/// %15 = vector.load %subview_7[%c0, %c0, %c0]
+/// %16 = vector.fma %8, %15, %arg11 : !type
+/// %17 = vector.fma %10, %15, %arg12 : !type
+/// %18 = vector.fma %12, %15, %arg13 : !type
+/// %19 = vector.fma %14, %15, %arg14 : !type
+/// scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
+/// }
+/// scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
+/// }
+/// vector.store %5#0, %subview_2[%c0, %c0]
+/// vector.store %5#1, %subview_3[%c0, %c0]
+/// vector.store %5#2, %subview_4[%c0, %c0]
+/// vector.store %5#3, %subview_5[%c0, %c0]
+///
struct VectorContractToFMA
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
index f1f24e4e53cb66..3b57f159108ead 100644
--- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -4,7 +4,7 @@
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
- func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
%c1 = arith.constant 1 : index
@@ -46,7 +46,7 @@
// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-// CHECK-LABEL: func.func @simple_gemm(
+// CHECK-LABEL: func.func @tiled_gemm_hoist_vector_transfer_operations(
// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
@@ -95,3 +95,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module { // Negative test: the contraction is not inside a tiled scf.for loop nest, so the hoisting pattern is expected to leave the IR unchanged.
+ memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+ func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+ %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+ %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+ %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+ vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ }
+ return %alloc : memref<8x24x32x64xf32>
+ }
+}
+
+// CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting
+// CHECK: memref.subview
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module { // Negative test: the transfer reads come directly from tensor function arguments (no memref.subview, no loop nest), so the IR is expected to stay unchanged.
+ func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+ %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+ %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+ %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+ return %4 : tensor<4x64xf32>
+ }
+}
+
+
+// CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
index ba11074e3c9637..f18c0dcb573d70 100644
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -1,106 +1,41 @@
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
- func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+
+ func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
%cst = arith.constant 0.000000e+00 : f32
- %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
%c1 = arith.constant 1 : index
- %c24 = arith.constant 24 : index
+ %c16 = arith.constant 16 : index
%c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
- %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
- scf.forall (%arg1, %arg2) in (8, 24) {
- %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
- vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
- %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
- scf.for %arg3 = %c0 to %c32 step %c4 {
- scf.for %arg4 = %c0 to %c64 step %c64 {
- %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
- %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
- %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
- %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
- %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
- %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
- %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
- %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
- %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
- scf.yield %6 : vector<4x64xf32>
- }
+
+ scf.for %arg5 = %c0 to %c32 step %c4 {
+ scf.for %arg6 = %c0 to %c128 step %c64 {
+ %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+ %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+ %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+ %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+ %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+ %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+ %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
scf.yield %3 : vector<4x64xf32>
}
- vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+ scf.yield %con1 : vector<4x64xf32>
}
+ vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
}
}
- return %alloc : memref<8x24x32x64xf32>
+ return
}
-// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-
-// CHECK-LABEL: func.func @simple_gemm(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
-// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
-// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
-// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
-// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
-// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
-// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
-// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
-// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
-// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
-// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
-// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
-// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
-// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
-// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
-// CHECK: }
-// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
-// CHECK: }
-// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32>
-// CHECK: }
+// CHECK-NOT: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
>From 8aa56e403e3a30eaae5d3deba6a73072b6f60df7 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 21:21:38 -0800
Subject: [PATCH 3/3] Move vector.contract-to-FMA lowering into a separate PR
---
.../Linalg/TransformOps/LinalgTransformOps.td | 13 -
.../Dialect/Linalg/Transforms/Transforms.h | 5 -
.../TransformOps/LinalgTransformOps.cpp | 8 -
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 -
.../Linalg/Transforms/VectorContractToFMA.cpp | 410 ------------------
.../Dialect/Linalg/hoist-vector-transfer.mlir | 171 ++++----
.../Linalg/vector-contract-to-fma.mlir | 48 --
7 files changed, 78 insertions(+), 578 deletions(-)
delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
delete mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12b3ddb9c74904..6b890272bb6b49 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,7 +106,6 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
-
def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.hoist_vector_transfer",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -118,18 +117,6 @@ def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
-
-def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
- "apply_patterns.vector.contract_to_fma",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
- let description = [{
- Collects pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
- sequence of vector FMAs.
- }];
-
- let assemblyFormat = "attr-dict";
-}
-
def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.pad_vectorization",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6f639b45408d87..8a06df4fed3632 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,15 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
-
/// Pattern to hoists the vector transfer reads/writes outside the reduction and
/// k-loop for batch reduce matmul operation if licm fails.
void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
-/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
-/// sequence of vector FMAs.
-void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
-
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ed838f8476486f..61a3db7302d8db 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,19 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
linalg::populateFoldAddIntoDestPatterns(patterns);
}
-
void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populateHoistVectorTransferPatterns(patterns);
}
-
-void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
- RewritePatternSet &patterns) {
- linalg::populateVectorContractToFMAPatterns(patterns);
-}
-
-
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 90d926201cd753..63758a654f803f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -42,7 +42,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Vectorization.cpp
WinogradConv2D.cpp
HoistVectorTransfers.cpp
- VectorContractToFMA.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
deleted file mode 100644
index 4d3dac6a2b4d03..00000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
+++ /dev/null
@@ -1,410 +0,0 @@
-//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements lowering of vector contraction to vector fma.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "vector-contract-to-fma"
-
-using namespace mlir;
-
-/// Returns true if the \p map is transposed.
-static bool isTransposed(AffineMap map) {
- auto results = map.getResults();
- // Assert if the map does not have 3 or 4 inputs ([] m, n, k).
- assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) &&
- "3 or 4 input dim expected");
- // Assert if the result is not 2D.
- assert(map.getNumResults() == 2 && "Only 2 output dim expected");
-
- // Check the last two dimensions for transposition.
- auto dimExpr0 = dyn_cast<AffineDimExpr>(results[0]);
- auto dimExpr1 = dyn_cast<AffineDimExpr>(results[1]);
- assert((dimExpr0 && dimExpr1) && "Unexpected dim expression");
-
- // Exclude output map result.
- bool isOutputResultMap =
- dimExpr0 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) &&
- dimExpr1 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext());
- assert(!isOutputResultMap && "Output result map not expected");
-
- // It's transposed if result found as (k, m) or (n, k), else not transposed.
- if ((dimExpr0 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) &&
- dimExpr1 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) ||
- (dimExpr0 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) &&
- dimExpr1 ==
- mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext())))
- return true;
- return false;
-}
-
-
-// Structure to hold transformation context
-struct TransformationContext {
- scf::ForOp innerForOp;
- scf::ForOp outerForOp;
- scf::ForOp outermostLoop;
-};
-
-enum class MatMulType { Standard, Batch, BatchReduce };
-
-
-/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
-/// sequence of vector FMAs.
-///
-/// As an example, the following pseudo-code will be rewritten
-/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
-/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]}
-/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
-/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
-/// %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
-/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
-/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
-/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
-/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
-/// scf.yield %6 : !type
-/// }
-/// scf.yield %3 : !type
-/// }
-/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
-/// to:
-/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
-/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1]
-/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1]
-/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1]
-/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1]
-/// %1 = vector.load %subview_2[%c0, %c0]
-/// %2 = vector.load %subview_3[%c0, %c0]
-/// %3 = vector.load %subview_4[%c0, %c0]
-/// %4 = vector.load %subview_5[%c0, %c0]
-/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
-/// %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
-/// %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1]
-/// %7 = memref.load %subview_6[%c0, %c0, %c0]
-/// %8 = vector.broadcast %7 : f32 to !type
-/// %9 = memref.load %subview_6[%c0, %c1, %c0]
-/// %10 = vector.broadcast %9 : f32 to !type
-/// %11 = memref.load %subview_6[%c0, %c2, %c0]
-/// %12 = vector.broadcast %11 : f32 to !type
-/// %13 = memref.load %subview_6[%c0, %c3, %c0]
-/// %14 = vector.broadcast %13 : f32 to !type
-/// %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1]
-/// %15 = vector.load %subview_7[%c0, %c0, %c0]
-/// %16 = vector.fma %8, %15, %arg11 : !type
-/// %17 = vector.fma %10, %15, %arg12 : !type
-/// %18 = vector.fma %12, %15, %arg13 : !type
-/// %19 = vector.fma %14, %15, %arg14 : !type
-/// scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
-/// }
-/// scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
-/// }
-/// vector.store %5#0, %subview_2[%c0, %c0]
-/// vector.store %5#1, %subview_3[%c0, %c0]
-/// vector.store %5#2, %subview_4[%c0, %c0]
-/// vector.store %5#3, %subview_5[%c0, %c0])
-///
-struct VectorContractToFMA
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override {
- if (op.getKind() != vector::CombiningKind::ADD)
- return rewriter.notifyMatchFailure(
- op, "Unsupported combining kind, only supports ADD at the moment)");
-
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
-
- SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
- if (llvm::any_of(
- maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
- return rewriter.notifyMatchFailure(op, "Unexpected map");
-
- // Check for the variant of matrix multiply.
- auto iteratorTypes = op.getIteratorTypesArray();
- MatMulType matmulType;
- unsigned outerDimIndex = 0;
- if (iteratorTypes.size() > 3) {
- outerDimIndex = iteratorTypes.size() - 4;
- matmulType =
- iteratorTypes[outerDimIndex] == vector::IteratorType::parallel
- ? MatMulType::Batch
- : MatMulType::BatchReduce;
- outerDimIndex++;
- } else if (iteratorTypes.size() == 3) {
- matmulType = MatMulType::Standard;
- } else {
- return rewriter.notifyMatchFailure(op, "Not a gemm");
- }
-
- if (matmulType == MatMulType::Batch)
- return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
- if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
- iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
- iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
- return rewriter.notifyMatchFailure(op, "Not a gemm");
-
- SmallVector<Value, 4> results;
-
- auto lhs = op.getLhs();
- auto rhs = op.getRhs();
- auto acc = op.getAcc();
- auto lhsDefiningOp = lhs.getDefiningOp<vector::TransferReadOp>();
- auto rhsDefiningOp = rhs.getDefiningOp<vector::TransferReadOp>();
- auto accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
- if (!lhsDefiningOp || !rhsDefiningOp)
- return failure();
-
- // Accumulator can be a TransferReadOp but must be coming from the chain of
- // iterargs of nested loop.
- if (accDefiningOp)
- return failure();
-
- // Make sure the inputs being read are whole tensor or subview.
- if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
- !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
- return failure();
- }
-
- auto lhsType = cast<ShapedType>(lhsDefiningOp.getType());
- auto rhsType = cast<ShapedType>(rhsDefiningOp.getType());
- // auto accType = acc.getType();
- // auto accType = cast<ShapedType>(accDefiningOp.getType());
-
- if (matmulType == MatMulType::BatchReduce &&
- (lhsType.getRank() != 3 || rhsType.getRank() != 3))
- return failure();
-
- if (matmulType == MatMulType::Standard &&
- (lhsType.getRank() != 2 || rhsType.getRank() != 2))
- return failure();
-
- // Check for non-transposed matrices.
- auto mapLHS = maps[0];
- auto mapRHS = maps[1];
- if (matmulType == MatMulType::BatchReduce) {
- mapLHS = mapLHS.dropResult(0);
- mapRHS = mapRHS.dropResult(0);
- }
- if (isTransposed(mapLHS) || isTransposed(mapRHS))
- return rewriter.notifyMatchFailure(
- op, "Transposed matrices are not expected");
-
- // Verify that the accumulator is coming through a chain of iterargs of
- // nested loop and it is define by 'TransferReadOp'.
- //
- struct TransformationContext ctx;
-
- ctx.innerForOp = op->getParentOfType<scf::ForOp>();
- if (!ctx.innerForOp)
- return failure();
- ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
- if (!ctx.outerForOp)
- return failure();
- ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
- if (!ctx.outermostLoop)
- return failure();
-
- // Verify original inner loop has only one iterarg.
- auto origIterArgs = ctx.innerForOp.getRegionIterArgs();
- if (origIterArgs.size() != 1)
- return failure();
-
- // Verify chain, accumulator must be inner loop's iterarg.
- auto bbArg = dyn_cast<BlockArgument>(acc);
- if (!bbArg)
- return failure();
-
- // This block arg must be init arg, not induction variable.
- if (bbArg.getOwner() != ctx.innerForOp.getBody() ||
- bbArg.getArgNumber() == 0) {
- return failure();
- }
-
- // This iterarg must be intialized by outer loop's iterarg.
- auto innerInitValue =
- ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1];
- auto outerBBArg = dyn_cast<BlockArgument>(innerInitValue);
- if (!outerBBArg)
- return failure();
-
- // This block arg must be init arg, not induction variable.
- if (outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
- outerBBArg.getArgNumber() == 0) {
- return failure();
- }
-
- // Outer loop's iterarg initializer must be a TransferReadOp.
- acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1];
-
- // This must be defined by vector.transfer_read
- if (!acc.getDefiningOp<vector::TransferReadOp>())
- return failure();
-
- accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
- if (!accDefiningOp)
- return failure();
-
- // Only 2-D output expected.
- auto accType = cast<ShapedType>(accDefiningOp.getType());
- if (accType.getRank() != 2)
- return failure();
-
- int64_t M = accType.getDimSize(0);
- int64_t N = accType.getDimSize(1);
- int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
-
- // K must be 1.
- if (K != 1)
- return failure();
-
- auto accSubview = accDefiningOp.getSource();
- Location loc = op.getLoc();
-
- // Create M different <1xN> subviews.
- auto memrefType = cast<MemRefType>(accSubview.getType());
- auto elementType = memrefType.getElementType();
- SmallVector<OpFoldResult> mixedSizes = {rewriter.getIndexAttr(K),
- rewriter.getIndexAttr(N)};
- SmallVector<OpFoldResult> mixedStrides = {rewriter.getIndexAttr(1),
- rewriter.getIndexAttr(1)};
-
- rewriter.setInsertionPoint(
- ctx.outermostLoop.getBody(),
- std::prev(ctx.outermostLoop.getBody()->end(), 1));
-
- Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- SmallVector<Value, 4> subview_2_splits;
- for (int i = 0; i < M; i++) {
- SmallVector<OpFoldResult> mixedOffsets = {
- rewriter.getIndexAttr(i),
- rewriter.getIndexAttr(0),
- };
- auto split = rewriter.create<memref::SubViewOp>(
- loc, accSubview, mixedOffsets, mixedSizes, mixedStrides);
- subview_2_splits.push_back(split);
- }
-
- // Intialize each accumulator with a vector of size N
- SmallVector<Value, 4> initAccs;
- for (auto subview : subview_2_splits) {
- auto acc = rewriter.create<vector::LoadOp>(
- loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0});
- initAccs.push_back(acc);
- }
-
- // Create new outer loop with M different accumulators.
- auto newOuterForOp = rewriter.create<scf::ForOp>(
- loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
- ctx.outerForOp.getStep(), initAccs,
- [&](OpBuilder &nestedBuilder, Location loc, Value iv,
- ValueRange iterArgs) {
- // Create new inner loop with M accumulators.
- auto newInnerForOp = nestedBuilder.create<scf::ForOp>(
- loc, ctx.innerForOp.getLowerBound(),
- ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(),
- iterArgs,
- [&](OpBuilder &innerBuilder, Location loc, Value innerIv,
- ValueRange innerIterArgs) {
- IRMapping mapping;
- mapping.map(
- lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
- iv);
- mapping.map(
- lhsDefiningOp.getSource().getDefiningOp()->getOperand(3),
- innerIv);
- auto lhsClone = innerBuilder.clone(
- *lhsDefiningOp.getSource().getDefiningOp(), mapping);
-
- // Load and broadcast individual elements
- SmallVector<Value, 4> broadcasts;
- for (int i = 0; i < M; i++) {
- auto elem = innerBuilder.create<memref::LoadOp>(
- loc, lhsClone->getResult(0),
- ValueRange{
- c0,
- innerBuilder.create<arith::ConstantIndexOp>(loc, i),
- c0});
- auto bcast = innerBuilder.create<vector::BroadcastOp>(
- loc, VectorType::get({N}, elem.getType()), elem);
- broadcasts.push_back(bcast);
- }
-
- IRMapping rhsMapping;
- rhsMapping.map(
- rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
- iv);
- rhsMapping.map(
- rhsDefiningOp.getSource().getDefiningOp()->getOperand(2),
- innerIv);
- auto rhsClone = innerBuilder.clone(
- *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
- auto rowVec = innerBuilder.create<vector::LoadOp>(
- loc, VectorType::get({N}, elementType),
- rhsClone->getResult(0), ValueRange{c0, c0, c0});
-
- // Create M different FMAs using broadcasts and current
- // accumulator values.
- for (int i = 0; i < M; i++) {
- auto fma = innerBuilder.create<vector::FMAOp>(
- loc, broadcasts[i], rowVec, innerIterArgs[i]);
- results.push_back(fma);
- }
-
- // Yield all M results
- innerBuilder.create<scf::YieldOp>(loc, results);
- });
-
- // Yield results from inner loop to outer loop
- nestedBuilder.create<scf::YieldOp>(loc, newInnerForOp.getResults());
- });
-
- Value matResult = ctx.outerForOp.getResult(0);
- Operation *writeOp;
- for (auto user : matResult.getUsers()) {
- writeOp = dyn_cast<vector::TransferWriteOp>(user);
- if (writeOp)
- break;
- }
-
- // Store final results back to original locations.
- if (writeOp) {
- for (int i = 0; i < M; i++) {
- rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
- subview_2_splits[i],
- ValueRange{c0, c0});
- }
- }
-
- // Erase original write.
- if (writeOp)
- rewriter.eraseOp(writeOp);
-
- return success();
- }
-
-};
-
-void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
- patterns.add<VectorContractToFMA>(patterns.getContext());
-}
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
index 3b57f159108ead..b0b164951d4b32 100644
--- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -3,51 +3,48 @@
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
- func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
- %c1 = arith.constant 1 : index
- %c24 = arith.constant 24 : index
- %c64 = arith.constant 64 : index
- %c4 = arith.constant 4 : index
- %c32 = arith.constant 32 : index
- %c0 = arith.constant 0 : index
- %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
- scf.forall (%arg1, %arg2) in (8, 24) {
- %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
- vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
- %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
- scf.for %arg3 = %c0 to %c32 step %c4 {
- scf.for %arg4 = %c0 to %c64 step %c64 {
- %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
- scf.for %arg5 = %c0 to %c24 step %c1 {
- scf.for %arg6 = %c0 to %c64 step %c1 {
- %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
- %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
- %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
- %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
- %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
- %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
- vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
- }
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c64 = arith.constant 64 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ scf.for %arg3 = %c0 to %c32 step %c4 {
+ scf.for %arg4 = %c0 to %c64 step %c64 {
+ %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ scf.for %arg5 = %c0 to %c24 step %c1 {
+ scf.for %arg6 = %c0 to %c64 step %c1 {
+ %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+ %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+ %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+ %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+ vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
}
}
}
}
- return %alloc : memref<8x24x32x64xf32>
}
-
+ return %alloc : memref<8x24x32x64xf32>
+}
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-
// CHECK-LABEL: func.func @tiled_gemm_hoist_vector_transfer_operations(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
@@ -84,44 +81,39 @@
// CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32>
// CHECK: }
-
-
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.vector.hoist_vector_transfer
- } : !transform.any_op
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
}
-
// -----
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-module {
- memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
- func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
- %c0 = arith.constant 0 : index
- %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
- scf.forall (%arg1, %arg2) in (8, 24) {
- %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
- vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
- %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
- %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
- %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
- %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
- %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
- vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
- }
- return %alloc : memref<8x24x32x64xf32>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+ %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+ %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+ %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+ vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
}
+ return %alloc : memref<8x24x32x64xf32>
}
// CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting
@@ -135,36 +127,31 @@ module {
// CHECK-NEXT: vector.transfer_write
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.vector.hoist_vector_transfer
- } : !transform.any_op
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
}
-
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-module {
- func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
- %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
- %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
- %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
- %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
- return %4 : tensor<4x64xf32>
- }
+func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+ %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+ %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+ %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+ return %4 : tensor<4x64xf32>
}
-
// CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
@@ -174,13 +161,11 @@ module {
// CHECK-NEXT: return
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.vector.hoist_vector_transfer
- } : !transform.any_op
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.hoist_vector_transfer
+ } : !transform.any_op
+ transform.yield
+ }
}
-
-
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
deleted file mode 100644
index f18c0dcb573d70..00000000000000
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ /dev/null
@@ -1,48 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
- func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c1 = arith.constant 1 : index
- %c16 = arith.constant 16 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c4 = arith.constant 4 : index
- %c32 = arith.constant 32 : index
- %c0 = arith.constant 0 : index
-
- scf.for %arg5 = %c0 to %c32 step %c4 {
- scf.for %arg6 = %c0 to %c128 step %c64 {
- %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
- %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
- %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
- %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
- %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
- %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
- %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
- %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
- %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
- scf.yield %3 : vector<4x64xf32>
- }
- scf.yield %con1 : vector<4x64xf32>
- }
- vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
- }
- }
- return
- }
-
-// CHECK-NOT: vector.fma
-
- module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %0 {
- transform.apply_patterns.vector.contract_to_fma
- } : !transform.any_op
- transform.yield
- }
- }
More information about the Mlir-commits
mailing list