[Mlir-commits] [mlir] [MLIR][Linalg] Lower vector.contract to chain of vector.fma for batch reduce matmul (PR #121885)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 6 20:50:24 PST 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/121885
>From 000bf140c3a4f5bcb43eea39b1cf935ef86d2f55 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 20:07:31 -0800
Subject: [PATCH 1/2] initial code push for lowering vector contract to fma
---
.../Linalg/TransformOps/LinalgTransformOps.td | 11 +
.../Dialect/Linalg/Transforms/Transforms.h | 6 +
.../TransformOps/LinalgTransformOps.cpp | 5 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/VectorContractToFMA.cpp | 410 ++++++++++++++++++
.../Linalg/vector-contract-to-fma.mlir | 48 ++
6 files changed, 481 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
create mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 081bf9b6d3b239..14d513a4a7d800 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,6 +106,17 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower a vector contraction operation (batch-reduce
+    matmul) for a GEMM of size MxN to a sequence of vector FMAs.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.pad_vectorization",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..307f365bcea5f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,6 +1824,12 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
+
+/// Pattern to lower a vector contraction operation (batch-reduce matmul) for a
+/// GEMM of size MxN to a sequence of vector FMAs.
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a1d619c8cd19dc..ba407e30aae9bd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,6 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
linalg::populateFoldAddIntoDestPatterns(patterns);
}
+// Forwards to linalg::populateVectorContractToFMAPatterns, which adds the
+// vector.contract -> vector.fma lowering to `patterns`.
+void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ linalg::populateVectorContractToFMAPatterns(patterns);
+}
+
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..900a9af0ca9883 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
DecomposeGenericByUnfoldingPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp
+ VectorContractToFMA.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 00000000000000..4d3dac6a2b4d03
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,410 @@
+//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector contraction to vector fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "vector-contract-to-fma"
+
+using namespace mlir;
+
+/// Returns true if the \p map is transposed.
+///
+/// \p map is expected to have 3 or 4 inputs and exactly 2 results, both plain
+/// dimension expressions. Naming the last three inputs (m, n, k), the map is
+/// transposed when its results are (k, m) or (n, k); any other result pair is
+/// reported as not transposed. The output layout (m, n) is not a valid input
+/// and is rejected by an assertion.
+static bool isTransposed(AffineMap map) {
+  auto results = map.getResults();
+  // Assert if the map does not have 3 or 4 inputs ([] m, n, k).
+  assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) &&
+         "3 or 4 input dim expected");
+  // Assert if the result is not 2D.
+  assert(map.getNumResults() == 2 && "Only 2 output dim expected");
+
+  // Check the last two dimensions for transposition.
+  auto dimExpr0 = dyn_cast<AffineDimExpr>(results[0]);
+  auto dimExpr1 = dyn_cast<AffineDimExpr>(results[1]);
+  assert((dimExpr0 && dimExpr1) && "Unexpected dim expression");
+
+  // m, n and k are the three trailing input dimensions of the map.
+  MLIRContext *context = map.getContext();
+  unsigned numInputs = map.getNumInputs();
+  AffineExpr m = mlir::getAffineDimExpr(numInputs - 3, context);
+  AffineExpr n = mlir::getAffineDimExpr(numInputs - 2, context);
+  AffineExpr k = mlir::getAffineDimExpr(numInputs - 1, context);
+
+  // Exclude output map result. The flag is only read by the assertion, so it
+  // is marked [[maybe_unused]] to keep NDEBUG builds free of warnings.
+  [[maybe_unused]] bool isOutputResultMap = dimExpr0 == m && dimExpr1 == n;
+  assert(!isOutputResultMap && "Output result map not expected");
+
+  // It's transposed if result found as (k, m) or (n, k), else not transposed.
+  return (dimExpr0 == k && dimExpr1 == m) || (dimExpr0 == n && dimExpr1 == k);
+}
+
+
+namespace {
+
+/// The scf.for loops found around a matched vector.contract, from the
+/// innermost to the outermost one.
+struct TransformationContext {
+  /// Nearest scf.for enclosing the contraction.
+  scf::ForOp innerForOp;
+  /// Nearest scf.for enclosing `innerForOp`.
+  scf::ForOp outerForOp;
+  /// Nearest scf.for enclosing `outerForOp`.
+  scf::ForOp outermostLoop;
+};
+
+/// Matrix-multiplication variants told apart by the rewrite pattern.
+enum class MatMulType { Standard, Batch, BatchReduce };
+
+} // namespace
+
+
+/// Pattern to lower a vector contraction operation (batch-reduce matmul) for a
+/// GEMM of size MxN to a sequence of vector FMAs.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]}
+/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+/// %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
+/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+/// scf.yield %6 : !type
+/// }
+/// scf.yield %3 : !type
+/// }
+/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1]
+/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1]
+/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1]
+/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1]
+/// %1 = vector.load %subview_2[%c0, %c0]
+/// %2 = vector.load %subview_3[%c0, %c0]
+/// %3 = vector.load %subview_4[%c0, %c0]
+/// %4 = vector.load %subview_5[%c0, %c0]
+/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
+/// %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
+/// %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1]
+/// %7 = memref.load %subview_6[%c0, %c0, %c0]
+/// %8 = vector.broadcast %7 : f32 to !type
+/// %9 = memref.load %subview_6[%c0, %c1, %c0]
+/// %10 = vector.broadcast %9 : f32 to !type
+/// %11 = memref.load %subview_6[%c0, %c2, %c0]
+/// %12 = vector.broadcast %11 : f32 to !type
+/// %13 = memref.load %subview_6[%c0, %c3, %c0]
+/// %14 = vector.broadcast %13 : f32 to !type
+/// %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1]
+/// %15 = vector.load %subview_7[%c0, %c0, %c0]
+/// %16 = vector.fma %8, %15, %arg11 : !type
+/// %17 = vector.fma %10, %15, %arg12 : !type
+/// %18 = vector.fma %12, %15, %arg13 : !type
+/// %19 = vector.fma %14, %15, %arg14 : !type
+/// scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
+/// }
+/// scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
+/// }
+/// vector.store %5#0, %subview_2[%c0, %c0]
+/// vector.store %5#1, %subview_3[%c0, %c0]
+/// vector.store %5#2, %subview_4[%c0, %c0]
+/// vector.store %5#3, %subview_5[%c0, %c0]
+///
+struct VectorContractToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  /// Matches a unit-K `vector.contract` nested in two scf.for loops that carry
+  /// an MxN accumulator, and re-emits the loop nest with M row accumulators
+  /// updated by `vector.fma`.
+  ///
+  /// All preconditions are checked before any IR is created, so a failed match
+  /// leaves the IR untouched. On success the new loop nest and the row stores
+  /// are inserted at the end of the outermost loop body and the original
+  /// `vector.transfer_write` is erased; the original loop nest is left without
+  /// users.
+  ///
+  /// NOTE(review): the bodies of the two matched loops are not inspected beyond
+  /// the operations named below; anything else they contain is not carried over
+  /// to the new loop nest.
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported combining kind, only supports ADD at the moment");
+
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
+
+    SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
+    if (llvm::any_of(
+            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
+      return rewriter.notifyMatchFailure(op, "Unexpected map");
+
+    // Check for the variant of matrix multiply. Only 3 or 4 iteration
+    // dimensions are accepted: isTransposed() below asserts on anything else.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    MatMulType matmulType;
+    unsigned outerDimIndex = 0;
+    if (iteratorTypes.size() == 4) {
+      matmulType = iteratorTypes[0] == vector::IteratorType::parallel
+                       ? MatMulType::Batch
+                       : MatMulType::BatchReduce;
+      outerDimIndex = 1;
+    } else if (iteratorTypes.size() == 3) {
+      matmulType = MatMulType::Standard;
+    } else {
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+    }
+
+    if (matmulType == MatMulType::Batch)
+      return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
+    if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+
+    auto lhsDefiningOp = op.getLhs().getDefiningOp<vector::TransferReadOp>();
+    auto rhsDefiningOp = op.getRhs().getDefiningOp<vector::TransferReadOp>();
+    if (!lhsDefiningOp || !rhsDefiningOp)
+      return failure();
+
+    // The accumulator must not be read directly next to the contraction; it
+    // has to come through the chain of iterargs of the nested loops, which is
+    // verified further down.
+    Value acc = op.getAcc();
+    if (acc.getDefiningOp<vector::TransferReadOp>())
+      return failure();
+
+    // Make sure the inputs being read are whole tensor or subview.
+    if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
+        !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
+      return failure();
+    }
+
+    auto lhsType = cast<ShapedType>(lhsDefiningOp.getType());
+    auto rhsType = cast<ShapedType>(rhsDefiningOp.getType());
+
+    if (matmulType == MatMulType::BatchReduce &&
+        (lhsType.getRank() != 3 || rhsType.getRank() != 3))
+      return failure();
+
+    if (matmulType == MatMulType::Standard &&
+        (lhsType.getRank() != 2 || rhsType.getRank() != 2))
+      return failure();
+
+    // Check for non-transposed matrices.
+    auto mapLHS = maps[0];
+    auto mapRHS = maps[1];
+    if (matmulType == MatMulType::BatchReduce) {
+      mapLHS = mapLHS.dropResult(0);
+      mapRHS = mapRHS.dropResult(0);
+    }
+    if (isTransposed(mapLHS) || isTransposed(mapRHS))
+      return rewriter.notifyMatchFailure(
+          op, "Transposed matrices are not expected");
+
+    // The rewrite consumes one batch slice and one K element per inner-loop
+    // iteration, so both extents must be 1 on either operand.
+    if (matmulType == MatMulType::BatchReduce &&
+        (lhsType.getDimSize(0) != 1 || rhsType.getDimSize(0) != 1))
+      return rewriter.notifyMatchFailure(op, "Batch extent must be 1");
+    if (lhsType.getDimSize(lhsType.getRank() - 1) != 1 ||
+        rhsType.getDimSize(rhsType.getRank() - 2) != 1)
+      return rewriter.notifyMatchFailure(op, "K extent must be 1");
+
+    // Collect the three scf.for loops enclosing the contraction.
+    TransformationContext ctx;
+
+    ctx.innerForOp = op->getParentOfType<scf::ForOp>();
+    if (!ctx.innerForOp)
+      return failure();
+    ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outerForOp)
+      return failure();
+    ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outermostLoop)
+      return failure();
+
+    // Both loops must carry the accumulator and nothing else: the new loops
+    // only carry the M row accumulators.
+    if (ctx.innerForOp.getRegionIterArgs().size() != 1 ||
+        ctx.outerForOp.getRegionIterArgs().size() != 1)
+      return failure();
+
+    // Verify chain, accumulator must be inner loop's iterarg.
+    auto bbArg = dyn_cast<BlockArgument>(acc);
+    if (!bbArg)
+      return failure();
+
+    // This block arg must be init arg, not induction variable.
+    if (bbArg.getOwner() != ctx.innerForOp.getBody() ||
+        bbArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // This iterarg must be initialized by outer loop's iterarg.
+    auto innerInitValue =
+        ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1];
+    auto outerBBArg = dyn_cast<BlockArgument>(innerInitValue);
+    if (!outerBBArg)
+      return failure();
+
+    // This block arg must be init arg, not induction variable.
+    if (outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
+        outerBBArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // Outer loop's iterarg initializer must be a TransferReadOp.
+    acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1];
+    auto accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
+    if (!accDefiningOp)
+      return failure();
+
+    // The row subviews created below start at [i, 0] of the source, so the
+    // accumulator read itself must start at the origin.
+    if (!llvm::all_of(accDefiningOp.getIndices(), isZeroIndex))
+      return failure();
+
+    // Only 2-D output expected.
+    auto accType = cast<ShapedType>(accDefiningOp.getType());
+    if (accType.getRank() != 2)
+      return failure();
+
+    // Row subviews and vector.load/store need a rank-2 memref; tensor sources
+    // are not handled.
+    Value accSubview = accDefiningOp.getSource();
+    auto memrefType = dyn_cast<MemRefType>(accSubview.getType());
+    if (!memrefType || memrefType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Expected a 2-D memref output");
+
+    // Both inputs must be read from rank-3 memref.subview ops: the subviews
+    // are cloned into the new loops and accessed with three indices.
+    auto lhsSubview =
+        lhsDefiningOp.getSource().getDefiningOp<memref::SubViewOp>();
+    auto rhsSubview =
+        rhsDefiningOp.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!lhsSubview || !rhsSubview)
+      return rewriter.notifyMatchFailure(op, "Inputs must be memref.subview");
+    if (lhsSubview.getType().getRank() != 3 ||
+        rhsSubview.getType().getRank() != 3)
+      return rewriter.notifyMatchFailure(op, "Expected rank-3 input subviews");
+
+    // Operand 0 of a subview is its source. The operands remapped below must
+    // exist and must be the induction variables of the matched loops:
+    // operand 1 (batch) on both sides, operand 3 (K) on the LHS and operand 2
+    // (K) on the RHS.
+    Value oldOuterIv = ctx.outerForOp.getInductionVar();
+    Value oldInnerIv = ctx.innerForOp.getInductionVar();
+    if (lhsSubview->getNumOperands() < 4 || rhsSubview->getNumOperands() < 3 ||
+        lhsSubview->getOperand(1) != oldOuterIv ||
+        lhsSubview->getOperand(3) != oldInnerIv ||
+        rhsSubview->getOperand(1) != oldOuterIv ||
+        rhsSubview->getOperand(2) != oldInnerIv)
+      return rewriter.notifyMatchFailure(
+          op, "Subview offsets are not the loop induction variables");
+
+    // The only user of the loop nest result must be a transfer_write that
+    // stores the tile back to where the accumulator was read from. It is
+    // located before any IR is created so that a failed match leaves the IR
+    // untouched.
+    Value matResult = ctx.outerForOp.getResult(0);
+    if (!matResult.hasOneUse())
+      return failure();
+    auto writeOp = dyn_cast<vector::TransferWriteOp>(*matResult.user_begin());
+    if (!writeOp || writeOp.getSource() != accSubview ||
+        !llvm::all_of(writeOp.getIndices(), isZeroIndex))
+      return rewriter.notifyMatchFailure(
+          op, "Result is not written back to the accumulator source");
+
+    int64_t M = accType.getDimSize(0);
+    int64_t N = accType.getDimSize(1);
+    Location loc = op.getLoc();
+
+    // Create M different <1xN> subviews.
+    auto elementType = memrefType.getElementType();
+    SmallVector<OpFoldResult> mixedSizes = {rewriter.getIndexAttr(1),
+                                            rewriter.getIndexAttr(N)};
+    SmallVector<OpFoldResult> mixedStrides = {rewriter.getIndexAttr(1),
+                                              rewriter.getIndexAttr(1)};
+
+    // Insert right before the terminator of the outermost loop body.
+    rewriter.setInsertionPoint(
+        ctx.outermostLoop.getBody(),
+        std::prev(ctx.outermostLoop.getBody()->end(), 1));
+
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value, 4> rowSubviews;
+    for (int64_t i = 0; i < M; ++i) {
+      SmallVector<OpFoldResult> mixedOffsets = {
+          rewriter.getIndexAttr(i),
+          rewriter.getIndexAttr(0),
+      };
+      auto split = rewriter.create<memref::SubViewOp>(
+          loc, accSubview, mixedOffsets, mixedSizes, mixedStrides);
+      rowSubviews.push_back(split);
+    }
+
+    // Initialize each accumulator with a vector of size N.
+    SmallVector<Value, 4> initAccs;
+    for (Value rowSubview : rowSubviews) {
+      auto rowAcc = rewriter.create<vector::LoadOp>(
+          loc, VectorType::get({N}, elementType), rowSubview,
+          ValueRange{c0, c0});
+      initAccs.push_back(rowAcc);
+    }
+
+    // Create new outer loop with M different accumulators.
+    auto newOuterForOp = rewriter.create<scf::ForOp>(
+        loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
+        ctx.outerForOp.getStep(), initAccs,
+        [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+            ValueRange iterArgs) {
+          // Create new inner loop with M accumulators.
+          auto newInnerForOp = nestedBuilder.create<scf::ForOp>(
+              loc, ctx.innerForOp.getLowerBound(),
+              ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(),
+              iterArgs,
+              [&](OpBuilder &innerBuilder, Location loc, Value innerIv,
+                  ValueRange innerIterArgs) {
+                // Clone the LHS subview with the old induction variables
+                // replaced by the new ones.
+                IRMapping lhsMapping;
+                lhsMapping.map(oldOuterIv, iv);
+                lhsMapping.map(oldInnerIv, innerIv);
+                auto lhsClone = innerBuilder.clone(*lhsSubview, lhsMapping);
+
+                // Load and broadcast individual elements.
+                SmallVector<Value, 4> broadcasts;
+                for (int64_t i = 0; i < M; ++i) {
+                  Value row =
+                      innerBuilder.create<arith::ConstantIndexOp>(loc, i);
+                  auto elem = innerBuilder.create<memref::LoadOp>(
+                      loc, lhsClone->getResult(0), ValueRange{c0, row, c0});
+                  auto bcast = innerBuilder.create<vector::BroadcastOp>(
+                      loc, VectorType::get({N}, elem.getType()), elem);
+                  broadcasts.push_back(bcast);
+                }
+
+                // Clone the RHS subview the same way and load its single row.
+                IRMapping rhsMapping;
+                rhsMapping.map(oldOuterIv, iv);
+                rhsMapping.map(oldInnerIv, innerIv);
+                auto rhsClone = innerBuilder.clone(*rhsSubview, rhsMapping);
+                auto rowVec = innerBuilder.create<vector::LoadOp>(
+                    loc, VectorType::get({N}, elementType),
+                    rhsClone->getResult(0), ValueRange{c0, c0, c0});
+
+                // Create M different FMAs using broadcasts and current
+                // accumulator values.
+                SmallVector<Value, 4> fmaResults;
+                for (int64_t i = 0; i < M; ++i) {
+                  auto fma = innerBuilder.create<vector::FMAOp>(
+                      loc, broadcasts[i], rowVec, innerIterArgs[i]);
+                  fmaResults.push_back(fma);
+                }
+
+                // Yield all M results.
+                innerBuilder.create<scf::YieldOp>(loc, fmaResults);
+              });
+
+          // Yield results from inner loop to outer loop.
+          nestedBuilder.create<scf::YieldOp>(loc, newInnerForOp.getResults());
+        });
+
+    // Store final results back to original locations.
+    for (int64_t i = 0; i < M; ++i) {
+      rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
+                                       rowSubviews[i], ValueRange{c0, c0});
+    }
+
+    // Erase original write; this leaves the original loop nest without users.
+    rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+
+};
+
+/// Adds the VectorContractToFMA rewrite pattern to \p patterns, constructing
+/// it with the context of \p patterns.
+void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
+ patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
new file mode 100644
index 00000000000000..f18c0dcb573d70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+ func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+
+ scf.for %arg5 = %c0 to %c32 step %c4 {
+ scf.for %arg6 = %c0 to %c128 step %c64 {
+ %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+ %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+ %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+ %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+ %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+ %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+ %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
+ scf.yield %3 : vector<4x64xf32>
+ }
+ scf.yield %con1 : vector<4x64xf32>
+ }
+ vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+ }
+ }
+ return
+ }
+
+// CHECK-NOT: vector.fma
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.vector.contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+ }
>From 927c5149687660599eb1044c8e7789c63ee92b19 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 20:50:10 -0800
Subject: [PATCH 2/2] added 2 more test-cases
---
.../Linalg/vector-contract-to-fma.mlir | 209 +++++++++++++++---
1 file changed, 175 insertions(+), 34 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
index f18c0dcb573d70..6c61ab0368c3fd 100644
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -1,48 +1,189 @@
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
- func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c1 = arith.constant 1 : index
- %c16 = arith.constant 16 : index
- %c64 = arith.constant 64 : index
- %c128 = arith.constant 128 : index
- %c4 = arith.constant 4 : index
- %c32 = arith.constant 32 : index
- %c0 = arith.constant 0 : index
-
- scf.for %arg5 = %c0 to %c32 step %c4 {
- scf.for %arg6 = %c0 to %c128 step %c64 {
- %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
- %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
- %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
- %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
- %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
- %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
- %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
- %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
- %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
- scf.yield %3 : vector<4x64xf32>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @lower_contract_to_fma(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c64 = arith.constant 64 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+ scf.forall (%arg1, %arg2) in (8, 24) {
+ %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+ %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+ scf.for %arg3 = %c0 to %c32 step %c4 {
+ scf.for %arg4 = %c0 to %c64 step %c64 {
+ %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
+ %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
+ %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+ %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+ %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+ %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+ scf.yield %6 : vector<4x64xf32>
}
- scf.yield %con1 : vector<4x64xf32>
+ scf.yield %3 : vector<4x64xf32>
}
- vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+ vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
}
}
- return
}
+ return %alloc : memref<8x24x32x64xf32>
+}
+
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+// CHECK-LABEL: func.func @lower_contract_to_fma(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
+// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
+// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
+// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
+// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
+// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
+// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
+// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK: }
+// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+// Transform script for the test case above: matches every func.func in the
+// payload and applies the vector.contract_to_fma pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.vector.contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Negative test: the accumulator operand of vector.contract (%2) comes straight
+// from a vector.transfer_read on a tensor rather than from an scf.for iter_arg,
+// so (per the test name) the contraction is expected to stay un-lowered; the
+// CHECK-NOT that follows verifies that no vector.fma is produced.
+func.func @matmul_without_iterarg_accumulator_so_no_lowering_to_fma(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+ %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+ %3 = vector.contract {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+ %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+ return %4 : tensor<4x64xf32>
+}
// CHECK-NOT: vector.fma
- module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %0 {
- transform.apply_patterns.vector.contract_to_fma
- } : !transform.any_op
- transform.yield
+// Transform script for this test case: matches every func.func in the payload
+// and applies the vector.contract_to_fma pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.vector.contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// Negative test: operand B is read through a transposing permutation map
+// (yielding vector<1x64x1xf32>, indexed via #mapTransposeB), so the contraction
+// is expected to stay un-lowered (see the CHECK-NOT below).
+// Shapes: A = 16x32x128 (batch x M x K), B = 16x128x64 (batch x K x N),
+// C = 32x64 (M x N).
+func.func @transpose_matrix_no_lowering_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c4 = arith.constant 4 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+
+ scf.for %arg5 = %c0 to %c32 step %c4 {
+ // N-tile loop: C and B have 64 columns, so a 64-wide tile may only start at 0.
+ scf.for %arg6 = %c0 to %c64 step %c64 {
+ %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+ %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+ %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+ // K loop: the reduction dimension of A and B has extent 128.
+ %con1 = scf.for %arg8 = %c0 to %c128 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+ %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+ %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+ %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+ %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+ %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
+ scf.yield %3 : vector<4x64xf32>
+ }
+ scf.yield %con1 : vector<4x64xf32>
+ }
+ vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
}
}
+ return
+}
+
+// CHECK-NOT: vector.fma
+
+// Transform script for this test case: matches every func.func in the payload
+// and applies the vector.contract_to_fma pattern set to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.vector.contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list