[Mlir-commits] [mlir] [MLIR][Linalg] Lower vector.contract to chain of vector.fma for batch reduce matmul (PR #121885)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 6 20:50:24 PST 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/121885

>From 000bf140c3a4f5bcb43eea39b1cf935ef86d2f55 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 20:07:31 -0800
Subject: [PATCH 1/2] initial code push for lowering vector contract to fma

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  11 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   6 +
 .../TransformOps/LinalgTransformOps.cpp       |   5 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/VectorContractToFMA.cpp | 410 ++++++++++++++++++
 .../Linalg/vector-contract-to-fma.mlir        |  48 ++
 6 files changed, 481 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
 create mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 081bf9b6d3b239..14d513a4a7d800 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,6 +106,17 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to lower a vector contraction operation (batch-reduce
+    matmul) for a GEMM of size MxN into a sequence of vector FMAs.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.linalg.pad_vectorization",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..307f365bcea5f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,6 +1824,12 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
 /// suffices for achieving the sum.
 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
 
+
+/// Patterns to lower a vector contraction operation (batch-reduce matmul) for
+/// a GEMM of size MxN into a sequence of vector FMAs.
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
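
A minimal usage sketch for this new entry point, assuming the greedy rewrite
driver (applyPatternsAndFoldGreedily); the wrapper function below is
hypothetical and not part of this patch:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Hypothetical helper: greedily applies the contract-to-FMA lowering
  // patterns inside `op` (e.g. a func.func).
  static LogicalResult lowerContractsToFMA(Operation *op) {
    RewritePatternSet patterns(op->getContext());
    linalg::populateVectorContractToFMAPatterns(patterns);
    return applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
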
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a1d619c8cd19dc..ba407e30aae9bd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,6 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
   linalg::populateFoldAddIntoDestPatterns(patterns);
 }
 
+void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateVectorContractToFMAPatterns(patterns);
+}
+
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..900a9af0ca9883 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
+  VectorContractToFMA.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 00000000000000..4d3dac6a2b4d03
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,410 @@
+//===- VectorContractToFMA.cpp - Contraction to FMA lowering ---*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector.contract to a chain of vector.fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "vector-contract-to-fma"
+
+using namespace mlir;
+
+/// Returns true if the \p map is transposed.
+static bool isTransposed(AffineMap map) {
+  auto results = map.getResults();
+  // Assert if the map does not have 3 or 4 inputs ([batch,] m, n, k).
+  assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) &&
+         "3 or 4 input dim expected");
+  // Assert if the result is not 2D.
+  assert(map.getNumResults() == 2 && "Only 2 output dim expected");
+
+  // Check the last two dimensions for transposition.
+  auto dimExpr0 = dyn_cast<AffineDimExpr>(results[0]);
+  auto dimExpr1 = dyn_cast<AffineDimExpr>(results[1]);
+  assert((dimExpr0 && dimExpr1) && "Unexpected dim expression");
+
+  // Exclude output map result.
+  bool isOutputResultMap =
+      dimExpr0 ==
+          mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) &&
+      dimExpr1 ==
+          mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext());
+  assert(!isOutputResultMap && "Output result map not expected");
+
+  // The map is transposed if its result is (k, m) or (n, k); otherwise not.
+  if ((dimExpr0 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) &&
+       dimExpr1 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) ||
+      (dimExpr0 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) &&
+       dimExpr1 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext())))
+    return true;
+  return false;
+}
+
+/// Structure to hold the transformation context.
+struct TransformationContext {
+  scf::ForOp innerForOp;
+  scf::ForOp outerForOp;
+  scf::ForOp outermostLoop;
+};
+
+enum class MatMulType { Standard, Batch, BatchReduce };
+
+/// Pattern to lower a vector contraction operation (batch-reduce matmul) for a
+/// GEMM of size MxN into a sequence of vector FMAs.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]} 
+/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+///  %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+///   %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+///   %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] 
+///   %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///   %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///   %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+///   scf.yield %6 : !type
+///  }
+///  scf.yield %3 : !type
+/// }
+/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1] 
+/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1] 
+/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1] 
+/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1] 
+/// %1 = vector.load %subview_2[%c0, %c0] 
+/// %2 = vector.load %subview_3[%c0, %c0] 
+/// %3 = vector.load %subview_4[%c0, %c0] 
+/// %4 = vector.load %subview_5[%c0, %c0] 
+/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
+///   %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
+///     %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1] 
+///     %7 = memref.load %subview_6[%c0, %c0, %c0] 
+///     %8 = vector.broadcast %7 : f32 to !type
+///     %9 = memref.load %subview_6[%c0, %c1, %c0] 
+///     %10 = vector.broadcast %9 : f32 to !type
+///     %11 = memref.load %subview_6[%c0, %c2, %c0] 
+///     %12 = vector.broadcast %11 : f32 to !type
+///     %13 = memref.load %subview_6[%c0, %c3, %c0] 
+///     %14 = vector.broadcast %13 : f32 to !type
+///     %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1] 
+///     %15 = vector.load %subview_7[%c0, %c0, %c0] 
+///     %16 = vector.fma %8, %15, %arg11 : !type
+///     %17 = vector.fma %10, %15, %arg12 : !type
+///     %18 = vector.fma %12, %15, %arg13 : !type
+///     %19 = vector.fma %14, %15, %arg14 : !type
+///     scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
+///   }
+///   scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
+/// }
+/// vector.store %5#0, %subview_2[%c0, %c0] 
+/// vector.store %5#1, %subview_3[%c0, %c0] 
+/// vector.store %5#2, %subview_4[%c0, %c0] 
+/// vector.store %5#3, %subview_5[%c0, %c0]
+///
+struct VectorContractToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported combining kind, only supports ADD at the moment)");
+
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
+
+    SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
+    if (llvm::any_of(
+            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
+      return rewriter.notifyMatchFailure(op, "Unexpected map");
+
+    // Check for the variant of matrix multiply.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    MatMulType matmulType;
+    unsigned outerDimIndex = 0;
+    if (iteratorTypes.size() > 3) {
+      outerDimIndex = iteratorTypes.size() - 4;
+      matmulType =
+          iteratorTypes[outerDimIndex] == vector::IteratorType::parallel
+              ? MatMulType::Batch
+              : MatMulType::BatchReduce;
+      outerDimIndex++;
+    } else if (iteratorTypes.size() == 3) {
+      matmulType = MatMulType::Standard;
+    } else {
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+    }
+
+    if (matmulType == MatMulType::Batch)
+      return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
+    if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+
+    SmallVector<Value, 4> results;
+
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+    auto acc = op.getAcc();
+    auto lhsDefiningOp = lhs.getDefiningOp<vector::TransferReadOp>();
+    auto rhsDefiningOp = rhs.getDefiningOp<vector::TransferReadOp>();
+    auto accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
+    if (!lhsDefiningOp || !rhsDefiningOp)
+      return failure();
+
+    // The accumulator must not be defined directly by a TransferReadOp; it
+    // must come through the chain of iter_args of the nested loops.
+    if (accDefiningOp)
+      return failure();
+
+    // Make sure the inputs are read from the start of the tensor or subview.
+    if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
+        !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
+      return failure();
+    }
+
+    auto lhsType = cast<ShapedType>(lhsDefiningOp.getType());
+    auto rhsType = cast<ShapedType>(rhsDefiningOp.getType());
+
+    if (matmulType == MatMulType::BatchReduce &&
+        (lhsType.getRank() != 3 || rhsType.getRank() != 3))
+      return failure();
+
+    if (matmulType == MatMulType::Standard &&
+        (lhsType.getRank() != 2 || rhsType.getRank() != 2))
+      return failure();
+
+    // Check for non-transposed matrices.
+    auto mapLHS = maps[0];
+    auto mapRHS = maps[1];
+    if (matmulType == MatMulType::BatchReduce) {
+      mapLHS = mapLHS.dropResult(0);
+      mapRHS = mapRHS.dropResult(0);
+    }
+    if (isTransposed(mapLHS) || isTransposed(mapRHS))
+      return rewriter.notifyMatchFailure(
+          op, "Transposed matrices are not expected");
+
+    // Verify that the accumulator comes through a chain of iter_args of the
+    // nested loops and is ultimately defined by a 'TransferReadOp'.
+    struct TransformationContext ctx;
+
+    ctx.innerForOp = op->getParentOfType<scf::ForOp>();
+    if (!ctx.innerForOp)
+      return failure();
+    ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outerForOp)
+      return failure();
+    ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outermostLoop)
+      return failure();
+
+    // Verify that the original inner loop has only one iter_arg.
+    auto origIterArgs = ctx.innerForOp.getRegionIterArgs();
+    if (origIterArgs.size() != 1)
+      return failure();
+
+    // Verify the chain: the accumulator must be the inner loop's iter_arg.
+    auto bbArg = dyn_cast<BlockArgument>(acc);
+    if (!bbArg)
+      return failure();
+
+    // This block arg must be an init arg, not the induction variable.
+    if (bbArg.getOwner() != ctx.innerForOp.getBody() ||
+        bbArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // This iter_arg must be initialized by the outer loop's iter_arg.
+    auto innerInitValue =
+        ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1];
+    auto outerBBArg = dyn_cast<BlockArgument>(innerInitValue);
+    if (!outerBBArg)
+      return failure();
+
+    // This block arg must be an init arg, not the induction variable.
+    if (outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
+        outerBBArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // The outer loop's iter_arg initializer must be a TransferReadOp.
+    acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1];
+    accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
+    if (!accDefiningOp)
+      return failure();
+
+    // Only 2-D output expected.
+    auto accType = cast<ShapedType>(accDefiningOp.getType());
+    if (accType.getRank() != 2)
+      return failure();
+
+    int64_t M = accType.getDimSize(0);
+    int64_t N = accType.getDimSize(1);
+    int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
+
+    // K must be 1.
+    if (K != 1)
+      return failure();
+
+    auto accSubview = accDefiningOp.getSource();
+    Location loc = op.getLoc();
+
+    // Create M different <1xN> subviews.
+    auto memrefType = cast<MemRefType>(accSubview.getType());
+    auto elementType = memrefType.getElementType();
+    SmallVector<OpFoldResult> mixedSizes = {rewriter.getIndexAttr(K),
+                                            rewriter.getIndexAttr(N)};
+    SmallVector<OpFoldResult> mixedStrides = {rewriter.getIndexAttr(1),
+                                              rewriter.getIndexAttr(1)};
+
+    rewriter.setInsertionPoint(
+        ctx.outermostLoop.getBody(),
+        std::prev(ctx.outermostLoop.getBody()->end(), 1));
+
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value, 4> subview_2_splits;
+    for (int i = 0; i < M; i++) {
+      SmallVector<OpFoldResult> mixedOffsets = {
+          rewriter.getIndexAttr(i),
+          rewriter.getIndexAttr(0),
+      };
+      auto split = rewriter.create<memref::SubViewOp>(
+          loc, accSubview, mixedOffsets, mixedSizes, mixedStrides);
+      subview_2_splits.push_back(split);
+    }
+
+    // Initialize each accumulator with a vector of size N.
+    SmallVector<Value, 4> initAccs;
+    for (auto subview : subview_2_splits) {
+      auto acc = rewriter.create<vector::LoadOp>(
+          loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0});
+      initAccs.push_back(acc);
+    }
+
+    // Create new outer loop with M different accumulators.
+    auto newOuterForOp = rewriter.create<scf::ForOp>(
+        loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
+        ctx.outerForOp.getStep(), initAccs,
+        [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+            ValueRange iterArgs) {
+          // Create new inner loop with M accumulators.
+          auto newInnerForOp = nestedBuilder.create<scf::ForOp>(
+              loc, ctx.innerForOp.getLowerBound(),
+              ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(),
+              iterArgs,
+              [&](OpBuilder &innerBuilder, Location loc, Value innerIv,
+                  ValueRange innerIterArgs) {
+                IRMapping mapping;
+                mapping.map(
+                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
+                    iv);
+                mapping.map(
+                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(3),
+                    innerIv);
+                auto lhsClone = innerBuilder.clone(
+                    *lhsDefiningOp.getSource().getDefiningOp(), mapping);
+
+                // Load and broadcast individual elements
+                SmallVector<Value, 4> broadcasts;
+                for (int i = 0; i < M; i++) {
+                  auto elem = innerBuilder.create<memref::LoadOp>(
+                      loc, lhsClone->getResult(0),
+                      ValueRange{
+                          c0,
+                          innerBuilder.create<arith::ConstantIndexOp>(loc, i),
+                          c0});
+                  auto bcast = innerBuilder.create<vector::BroadcastOp>(
+                      loc, VectorType::get({N}, elem.getType()), elem);
+                  broadcasts.push_back(bcast);
+                }
+
+                IRMapping rhsMapping;
+                rhsMapping.map(
+                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
+                    iv);
+                rhsMapping.map(
+                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(2),
+                    innerIv);
+                auto rhsClone = innerBuilder.clone(
+                    *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
+                auto rowVec = innerBuilder.create<vector::LoadOp>(
+                    loc, VectorType::get({N}, elementType),
+                    rhsClone->getResult(0), ValueRange{c0, c0, c0});
+
+                // Create M different FMAs using broadcasts and current
+                // accumulator values.
+                for (int i = 0; i < M; i++) {
+                  auto fma = innerBuilder.create<vector::FMAOp>(
+                      loc, broadcasts[i], rowVec, innerIterArgs[i]);
+                  results.push_back(fma);
+                }
+
+                // Yield all M results
+                innerBuilder.create<scf::YieldOp>(loc, results);
+              });
+
+          // Yield results from inner loop to outer loop
+          nestedBuilder.create<scf::YieldOp>(loc, newInnerForOp.getResults());
+        });
+
+    Value matResult = ctx.outerForOp.getResult(0);
+    Operation *writeOp = nullptr;
+    for (auto user : matResult.getUsers()) {
+      writeOp = dyn_cast<vector::TransferWriteOp>(user);
+      if (writeOp)
+        break;
+    }
+
+    // Store the final results back to the original locations and erase the
+    // original write.
+    if (writeOp) {
+      for (int i = 0; i < M; i++) {
+        rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
+                                         subview_2_splits[i],
+                                         ValueRange{c0, c0});
+      }
+      rewriter.eraseOp(writeOp);
+    }
+
+    return success();
+  }
+};
+
+void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
+  patterns.add<VectorContractToFMA>(patterns.getContext());
+}
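
As a concrete illustration of the isTransposed check above, a hypothetical
standalone snippet (not part of the patch) building the affine maps the
pattern treats as non-transposed vs. transposed for a plain 3-dim GEMM
(d0 = m, d1 = n, d2 = k); only AffineMap and getAffineDimExpr from the MLIR
core API are used:

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"

  using namespace mlir;

  void buildGemmMaps() {
    MLIRContext ctx;
    AffineExpr m = getAffineDimExpr(0, &ctx);
    AffineExpr n = getAffineDimExpr(1, &ctx);
    AffineExpr k = getAffineDimExpr(2, &ctx);
    // Non-transposed operands: the LHS reads (m, k), the RHS reads (k, n).
    AffineMap lhs = AffineMap::get(3, 0, {m, k}, &ctx);
    AffineMap rhs = AffineMap::get(3, 0, {k, n}, &ctx);
    // Transposed operands, rejected by the pattern: (k, m) and (n, k).
    AffineMap lhsT = AffineMap::get(3, 0, {k, m}, &ctx);
    AffineMap rhsT = AffineMap::get(3, 0, {n, k}, &ctx);
    (void)lhs; (void)rhs; (void)lhsT; (void)rhsT;
  }
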
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
new file mode 100644
index 00000000000000..f18c0dcb573d70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+  func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c4 = arith.constant 4 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+
+    scf.for %arg5 = %c0 to %c32 step %c4 {
+      scf.for %arg6 = %c0 to %c128 step %c64 {
+        %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+        %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+          %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+            %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+            %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+            %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+            %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+            %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
+            scf.yield %3 : vector<4x64xf32>
+          }
+          scf.yield %con1 : vector<4x64xf32>
+        }
+        vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+      }
+    }
+    return
+  }
+
+// CHECK-NOT: vector.fma
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      transform.apply_patterns to %0 {
+        transform.apply_patterns.vector.contract_to_fma
+      } : !transform.any_op
+      transform.yield
+    }
+  }

>From 927c5149687660599eb1044c8e7789c63ee92b19 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 20:50:10 -0800
Subject: [PATCH 2/2] added 2 more test-cases

---
 .../Linalg/vector-contract-to-fma.mlir        | 209 +++++++++++++++---
 1 file changed, 175 insertions(+), 34 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
index f18c0dcb573d70..6c61ab0368c3fd 100644
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -1,48 +1,189 @@
 // RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
-  func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
-    %cst = arith.constant 0.000000e+00 : f32
-    %c1 = arith.constant 1 : index
-    %c16 = arith.constant 16 : index
-    %c64 = arith.constant 64 : index
-    %c128 = arith.constant 128 : index
-    %c4 = arith.constant 4 : index
-    %c32 = arith.constant 32 : index
-    %c0 = arith.constant 0 : index
-
-    scf.for %arg5 = %c0 to %c32 step %c4 {
-      scf.for %arg6 = %c0 to %c128 step %c64 {
-        %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-        %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
-        %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
-          %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
-            %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
-            %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
-            %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
-            %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
-            %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
-            scf.yield %3 : vector<4x64xf32>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @lower_contract_to_fma(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %c1 = arith.constant 1 : index
+  %c24 = arith.constant 24 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    scf.for %arg3 = %c0 to %c32 step %c4 {
+      scf.for %arg4 = %c0 to %c64 step %c64 {
+        %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+        %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
+          %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
+            %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+            %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+            %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+            %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+            %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+            scf.yield %6 : vector<4x64xf32>
           }
-          scf.yield %con1 : vector<4x64xf32>
+          scf.yield %3 : vector<4x64xf32>
         }
-        vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+        vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
       }
     }
-    return
   }
+  return %alloc : memref<8x24x32x64xf32>
+}
+
+// CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+// CHECK-LABEL:   func.func @lower_contract_to_fma(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 24 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 64 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK:           scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
+// CHECK:             %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
+// CHECK:               scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
+// CHECK:                 %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK:                   %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK:                     %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                     %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
+// CHECK:                     %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
+// CHECK:                     %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
+// CHECK:                     %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
+// CHECK:                     scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK:                 }
+// CHECK:                 vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[VAL_11]] : memref<8x24x32x64xf32>
+// CHECK:         }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.vector.contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
+#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
+#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_without_iterarg_accumulator_so_no_lowering_to_fma(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+  %3 = vector.contract {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+  %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+  return %4 : tensor<4x64xf32>
+}
 
 // CHECK-NOT: vector.fma
 
-  module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-      transform.apply_patterns to %0 {
-        transform.apply_patterns.vector.contract_to_fma
-      } : !transform.any_op
-      transform.yield
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.vector.contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @transpose_matrix_no_lowering_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c4 = arith.constant 4 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+
+  scf.for %arg5 = %c0 to %c32 step %c4 {
+    scf.for %arg6 = %c0 to %c128 step %c64 {
+      %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+      %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+      %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+        %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+          %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+          %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+          %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+          %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+          %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
+          scf.yield %3 : vector<4x64xf32>
+        }
+        scf.yield %con1 : vector<4x64xf32>
+      }
+      vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
     }
   }
+  return
+}
+
+// CHECK-NOT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.vector.contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}


