[Mlir-commits] [mlir] [MLIR][Linalg] Transform pass to optimize and lower vector contract operation to FMA (PR #121748)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 6 21:21:54 PST 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/121748

>From 84130f65a351a731f6ba153f5b147db59d7ec09c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 2 Jan 2025 07:27:29 -0800
Subject: [PATCH 1/3] =?UTF-8?q?initial=20changes=20for=20upstreaming=20hoi?=
 =?UTF-8?q?st=20vector=20transfers=C2=A0and=20contract=20to=20fma?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  24 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  10 +
 .../TransformOps/LinalgTransformOps.cpp       |  13 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +
 .../Transforms/HoistVectorTransfers.cpp       | 234 ++++++++++++
 .../Linalg/Transforms/VectorContractToFMA.cpp | 356 ++++++++++++++++++
 .../Dialect/Linalg/hoist-vector-transfer.mlir |  97 +++++
 .../Linalg/vector-contract-to-fma.mlir        | 113 ++++++
 8 files changed, 849 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
 create mode 100644 mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
 create mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..d614bb4789767b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,6 +106,30 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+
+def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.hoist_vector_transfer",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Hoists the vector transfer reads/writes outside the reduction and k-loop.
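+
+    A minimal usage sketch, mirroring the test added in this patch (`%func` is
+    a handle to the containing `func.func`):
+
+    ```mlir
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    ```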
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Implements the lowering of a vector contraction op for a GEMM of size MxN
+    to a sequence of vector FMAs wrapped inside an scf.for loop with iter_args
+    accumulating the FMA results.
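+
+    A minimal usage sketch, mirroring the test added in this patch (`%func` is
+    a handle to the containing `func.func`):
+
+    ```mlir
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.contract_to_fma
+    } : !transform.any_op
+    ```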
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.linalg.pad_vectorization",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..d35f99826004ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,6 +1824,16 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
 /// suffices for achieving the sum.
 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
 
+
+/// Pattern to hoist the vector transfer reads/writes outside the reduction and
+/// k-loop.
+void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
+
+
+/// Pattern to lower vector contraction op for GEMM of size MxN to
+/// sequence of vector FMAs
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 221ca27b80fdd0..849632aed8a13b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,6 +262,19 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
   linalg::populateFoldAddIntoDestPatterns(patterns);
 }
 
+
+void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateHoistVectorTransferPatterns(patterns);
+}
+
+
+void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateVectorContractToFMAPatterns(patterns);
+}
+
+
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..90d926201cd753 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,6 +41,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
+  HoistVectorTransfers.cpp
+  VectorContractToFMA.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
new file mode 100644
index 00000000000000..5911d36cbafef6
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
@@ -0,0 +1,234 @@
+//===-HoistVectorTransfers.cpp -----------------------------------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements tile configuration hoisting on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+static FailureOr<SmallVector<vector::TransferReadOp>>
+getContractOperands(vector::ContractionOp contractOp) {
+  SmallVector<vector::TransferReadOp> list;
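+  // vector.contract operands are ordered as lhs (0), rhs (1), and acc (2);
+  // each must be produced by a vector.transfer_read for this pattern to apply.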
+  for (int i = 0; i < 3; i++) {
+    auto vectorReadOp =
+        contractOp.getOperand(i).getDefiningOp<vector::TransferReadOp>();
+    if (!vectorReadOp)
+      return failure();
+    list.push_back(vectorReadOp);
+  }
+  return list;
+}
+
+static FailureOr<SmallVector<memref::SubViewOp>>
+getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
+  SmallVector<memref::SubViewOp> list;
+  for (vector::TransferReadOp readOp : readOps) {
+    auto subViewOp = readOp.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    if (!subViewOp)
+      return failure();
+    list.push_back(subViewOp);
+  }
+  return list;
+}
+
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp) {
+  SmallVector<scf::ForOp> list;
+  Operation *current = contractOp;
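+  // Walk outward from the contract op; the expected nest, innermost first, is
+  // k -> reduction -> n -> m (validated against the subview offsets in
+  // checkNestedLoop).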
+  for (int i = 0; i < 4; i++) {
+    Operation *parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(dyn_cast<scf::ForOp>(parent));
+    current = parent;
+  }
+  return list;
+}
+
+static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
+                                     SmallVector<memref::SubViewOp> subviews) {
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  Value ivK = loops[0].getInductionVar();
+  if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+    return failure();
+
+  Value ivReduction = loops[1].getInductionVar();
+  if (ivReduction != subviewOpLhsOffsets[0] ||
+      ivReduction != subviewOpRhsOffsets[0])
+    return failure();
+
+  Value ivN = loops[2].getInductionVar();
+  if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+    return failure();
+
+  Value ivM = loops[3].getInductionVar();
+  if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+    return failure();
+
+  return success();
+}
+
+struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Check the vector contract operation satisfies the required pattern.
+    // Check the Acc, Lhs, and Rhs of contract operation
+
+    auto operands = getContractOperands(contractOp);
+    if (failed(operands))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid operands for contract op");
+
+    auto readOps = *operands;
+    auto vectorReadOpAcc = readOps[2];
+    auto vectorReadOpLhs = readOps[0];
+    auto vectorReadOpRhs = readOps[1];
+
+    // Check whether the operand of vector transfer read is a subview
+    auto subviews = getReadOperands(readOps);
+    if (failed(subviews))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Vector read op operands are not a subview");
+
+    // Check the operation type MatMul, B-MatMul, or BR-MatMul
+    SmallVector<vector::IteratorType> contractIteratorTypes =
+        contractOp.getIteratorTypesArray();
+    int reductionCount =
+        std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
+                   vector::IteratorType::reduction);
+
+    auto vectorReadOpLhsType = cast<ShapedType>(vectorReadOpLhs.getType());
+    auto vectorReadOpRhsRank =
+        (cast<ShapedType>(vectorReadOpRhs.getType())).getRank();
+
+    if (reductionCount == 2 &&
+        (vectorReadOpLhsType.getRank() != 3 || vectorReadOpRhsRank != 3))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid rank for batch reduce operation");
+
+    if (reductionCount == 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Batch matmul operation not supported yet");
+
+    if (reductionCount > 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The vector contract operation is not a gemm");
+
+    // Check the K-dim to be 1
+    int64_t K =
+        vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
+    if (K != 1)
+      return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
+
+    // Check whether the linalg tiling + vector contract pattern matches for the
+    // 4-nested loop structure
+    auto loops = getNestedLoop(contractOp);
+    if (failed(loops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid loop nest in contract pattern");
+
+    auto checkLoops = checkNestedLoop(*loops, *subviews);
+    if (failed(checkLoops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Loops doesn't match the iv in subviews");
+
+    auto nestedLoops = *loops;
+    auto kForOp = nestedLoops[0];
+    auto reductionForOp = nestedLoops[1];
+
+    // Move the vector transfer read before the reduction and k loop
+    rewriter.setInsertionPoint(reductionForOp);
+    auto *cloneVectorReadOp = rewriter.clone(*vectorReadOpAcc);
+
+    // Code to re-create the reduction and k loop with iter args
+    auto vectorReadOpValue = cloneVectorReadOp->getResult(0);
+    auto newReductionForOp = rewriter.create<scf::ForOp>(
+        reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+        reductionForOp.getUpperBound(), reductionForOp.getStep(),
+        ValueRange{vectorReadOpValue},
+        [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
+            Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
+          auto newKForOp = rewriter.create<scf::ForOp>(
+              kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+              kForOp.getStep(), iterArgsNewReductionForOp,
+              [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+                  Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+                IRMapping mapper;
+                mapper.map(reductionForOp.getInductionVar(),
+                           ivNewReductionForOp);
+                mapper.map(kForOp.getInductionVar(), ivNewKForOp);
+
+                for (auto &op : kForOp.getBody()->without_terminator()) {
+                  rewriterNewKForOp.clone(op, mapper);
+                }
+                rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp,
+                                                       iterArgsNewKForOp);
+              });
+          rewriterNewReductionForOp.create<scf::YieldOp>(
+              locNewReductionForOp, newKForOp.getResult(0));
+        });
+
+    // Code to hoist vector transfer write after reduction loop and also to
+    // update the yield of k loop
+    auto newKForOp =
+        llvm::dyn_cast<scf::ForOp>(newReductionForOp.getBody()->front());
+    Value newcontractOpValue;
+    vector::TransferWriteOp vectorWriteOperation;
+    Block *bodyBlock = newKForOp.getBody();
+    for (auto &op : bodyBlock->getOperations()) {
+      if (auto vectorContractOp = llvm::dyn_cast<vector::ContractionOp>(op)) {
+        vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1,
+                                    newKForOp.getRegionIterArgs()[0]);
+        newcontractOpValue = vectorContractOp.getResult();
+      }
+      if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
+        yieldOp.setOperand(0, newcontractOpValue);
+      }
+      if (auto vectorWriteOp = llvm::dyn_cast<vector::TransferWriteOp>(op)) {
+        vectorWriteOperation = vectorWriteOp;
+      }
+    }
+
+    vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0));
+    vectorWriteOperation->moveBefore(reductionForOp);
+
+    // Erase the old vector contract operation
+    for (auto result : contractOp->getResults()) {
+      for (auto *userOp : result.getUsers()) {
+        userOp->erase();
+      }
+    }
+    contractOp.erase();
+
+    return success();
+  }
+};
+
+void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
+  patterns.add<HoistVectorTransferOp>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 00000000000000..2a8132b93bdcb0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,356 @@
+
+//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector contraction to vector fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "vector-contract-to-fma"
+
+using namespace mlir;
+
+/// Returns true if the \p map is transposed.
+static bool isTransposed(AffineMap map) {
+  auto results = map.getResults();
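+  // The iteration dims are (m, n, k) or (b, m, n, k), so in the checks below
+  // m = numInputs - 3, n = numInputs - 2, and k = numInputs - 1.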
+  // Assert if the map does not have 3 or 4 inputs ([batch,] m, n, k).
+  assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) &&
+         "3 or 4 input dim expected");
+  // Assert if the result is not 2D.
+  assert(map.getNumResults() == 2 && "Only 2 output dim expected");
+
+  // Check the last two dimensions for transposition.
+  auto dimExpr0 = dyn_cast<AffineDimExpr>(results[0]);
+  auto dimExpr1 = dyn_cast<AffineDimExpr>(results[1]);
+  assert((dimExpr0 && dimExpr1) && "Unexpected dim expression");
+
+  // Exclude output map result.
+  bool isOutputResultMap =
+      dimExpr0 ==
+          mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) &&
+      dimExpr1 ==
+          mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext());
+  assert(!isOutputResultMap && "Output result map not expected");
+
+  // It's transposed if result found as (k, m) or (n, k), else not transposed.
+  if ((dimExpr0 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) &&
+       dimExpr1 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) ||
+      (dimExpr0 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) &&
+       dimExpr1 ==
+           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext())))
+    return true;
+  return false;
+}
+
+
+// Structure to hold transformation context
+struct TransformationContext {
+  scf::ForOp innerForOp;
+  scf::ForOp outerForOp;
+  scf::ForOp outermostLoop;
+};
+
+enum class MatMulType { Standard, Batch, BatchReduce };
+
+struct VectorContractToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported combining kind, only supports ADD at the moment)");
+
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
+
+    SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
+    if (llvm::any_of(
+            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
+      return rewriter.notifyMatchFailure(op, "Unexpected map");
+
+    // Check for the variant of matrix multiply.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    MatMulType matmulType;
+    unsigned outerDimIndex = 0;
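+    // With more than three iterators, the dim just before (m, n, k) decides
+    // between batch (parallel) and batch-reduce (reduction) matmul.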
+    if (iteratorTypes.size() > 3) {
+      outerDimIndex = iteratorTypes.size() - 4;
+      matmulType =
+          iteratorTypes[outerDimIndex] == vector::IteratorType::parallel
+              ? MatMulType::Batch
+              : MatMulType::BatchReduce;
+      outerDimIndex++;
+    } else if (iteratorTypes.size() == 3) {
+      matmulType = MatMulType::Standard;
+    } else {
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+    }
+
+    if (matmulType == MatMulType::Batch)
+      return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
+    if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
+        iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
+      return rewriter.notifyMatchFailure(op, "Not a gemm");
+
+    SmallVector<Value, 4> results;
+
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+    auto acc = op.getAcc();
+    auto lhsDefiningOp = lhs.getDefiningOp<vector::TransferReadOp>();
+    auto rhsDefiningOp = rhs.getDefiningOp<vector::TransferReadOp>();
+    auto accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
+    if (!lhsDefiningOp || !rhsDefiningOp)
+      return failure();
+
+    // The accumulator must not be read directly by a TransferReadOp here; it
+    // must instead arrive through the chain of iter_args of the nested loops.
+    if (accDefiningOp)
+      return failure();
+
+    // Make sure the inputs are read from the start of a whole tensor or
+    // subview (all read indices are zero).
+    if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
+        !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
+      return failure();
+    }
+
+    auto lhsType = cast<ShapedType>(lhsDefiningOp.getType());
+    auto rhsType = cast<ShapedType>(rhsDefiningOp.getType());
+
+    if (matmulType == MatMulType::BatchReduce &&
+        (lhsType.getRank() != 3 || rhsType.getRank() != 3))
+      return failure();
+
+    if (matmulType == MatMulType::Standard &&
+        (lhsType.getRank() != 2 || rhsType.getRank() != 2))
+      return failure();
+
+    // Check for non-transposed matrices.
+    auto mapLHS = maps[0];
+    auto mapRHS = maps[1];
+    if (matmulType == MatMulType::BatchReduce) {
+      mapLHS = mapLHS.dropResult(0);
+      mapRHS = mapRHS.dropResult(0);
+    }
+    if (isTransposed(mapLHS) || isTransposed(mapRHS))
+      return rewriter.notifyMatchFailure(
+          op, "Transposed matrices are not expected");
+
+    // Verify that the accumulator comes through a chain of iter_args of the
+    // nested loops and is ultimately defined by a 'TransferReadOp'.
+    struct TransformationContext ctx;
+
+    ctx.innerForOp = op->getParentOfType<scf::ForOp>();
+    if (!ctx.innerForOp)
+      return failure();
+    ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outerForOp)
+      return failure();
+    ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
+    if (!ctx.outermostLoop)
+      return failure();
+
+    // Verify original inner loop has only one iterarg.
+    auto origIterArgs = ctx.innerForOp.getRegionIterArgs();
+    if (origIterArgs.size() != 1)
+      return failure();
+
+    // Verify chain, accumulator must be inner loop's iterarg.
+    auto bbArg = dyn_cast<BlockArgument>(acc);
+    if (!bbArg)
+      return failure();
+
+    // This block arg must be init arg, not induction variable.
+    if (bbArg.getOwner() != ctx.innerForOp.getBody() ||
+        bbArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // This iterarg must be initialized by the outer loop's iterarg.
+    auto innerInitValue =
+        ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1];
+    auto outerBBArg = dyn_cast<BlockArgument>(innerInitValue);
+    if (!outerBBArg)
+      return failure();
+
+    // This block arg must be init arg, not induction variable.
+    if (outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
+        outerBBArg.getArgNumber() == 0) {
+      return failure();
+    }
+
+    // Outer loop's iterarg initializer must be a TransferReadOp.
+    acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1];
+
+    // This must be defined by a vector.transfer_read.
+    accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
+    if (!accDefiningOp)
+      return failure();
+
+    // Only 2-D output expected.
+    auto accType = cast<ShapedType>(accDefiningOp.getType());
+    if (accType.getRank() != 2)
+      return failure();
+
+    int64_t M = accType.getDimSize(0);
+    int64_t N = accType.getDimSize(1);
+    int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
+
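+    // With K == 1 the contraction degenerates to an outer-product update: each
+    // of the M output rows becomes acc_row += broadcast(lhs[m]) * rhs_row,
+    // which maps directly onto the vector.fma sequence built below.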
+    // K must be 1.
+    if (K != 1)
+      return failure();
+
+    auto accSubview = accDefiningOp.getSource();
+    Location loc = op.getLoc();
+
+    // Create M different <1xN> subviews.
+    auto memrefType = cast<MemRefType>(accSubview.getType());
+    auto elementType = memrefType.getElementType();
+    SmallVector<OpFoldResult> mixedSizes = {rewriter.getIndexAttr(K),
+                                            rewriter.getIndexAttr(N)};
+    SmallVector<OpFoldResult> mixedStrides = {rewriter.getIndexAttr(1),
+                                              rewriter.getIndexAttr(1)};
+
+    rewriter.setInsertionPoint(
+        ctx.outermostLoop.getBody(),
+        std::prev(ctx.outermostLoop.getBody()->end(), 1));
+
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value, 4> subview_2_splits;
+    for (int i = 0; i < M; i++) {
+      SmallVector<OpFoldResult> mixedOffsets = {
+          rewriter.getIndexAttr(i),
+          rewriter.getIndexAttr(0),
+      };
+      auto split = rewriter.create<memref::SubViewOp>(
+          loc, accSubview, mixedOffsets, mixedSizes, mixedStrides);
+      subview_2_splits.push_back(split);
+    }
+
+    // Initialize each accumulator with a vector of size N.
+    SmallVector<Value, 4> initAccs;
+    for (auto subview : subview_2_splits) {
+      auto acc = rewriter.create<vector::LoadOp>(
+          loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0});
+      initAccs.push_back(acc);
+    }
+
+    // Create new outer loop with M different accumulators.
+    auto newOuterForOp = rewriter.create<scf::ForOp>(
+        loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
+        ctx.outerForOp.getStep(), initAccs,
+        [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+            ValueRange iterArgs) {
+          // Create new inner loop with M accumulators.
+          auto newInnerForOp = nestedBuilder.create<scf::ForOp>(
+              loc, ctx.innerForOp.getLowerBound(),
+              ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(),
+              iterArgs,
+              [&](OpBuilder &innerBuilder, Location loc, Value innerIv,
+                  ValueRange innerIterArgs) {
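+                // Re-materialize the LHS subview inside the new loops,
+                // remapping its reduction-iv (operand 1) and k-iv (operand 3)
+                // offsets to the new induction variables; this assumes dynamic
+                // subview offsets, as produced by the tiling in the test.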
+                IRMapping mapping;
+                mapping.map(
+                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
+                    iv);
+                mapping.map(
+                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(3),
+                    innerIv);
+                auto lhsClone = innerBuilder.clone(
+                    *lhsDefiningOp.getSource().getDefiningOp(), mapping);
+
+                // Load and broadcast individual elements
+                SmallVector<Value, 4> broadcasts;
+                for (int i = 0; i < M; i++) {
+                  auto elem = innerBuilder.create<memref::LoadOp>(
+                      loc, lhsClone->getResult(0),
+                      ValueRange{
+                          c0,
+                          innerBuilder.create<arith::ConstantIndexOp>(loc, i),
+                          c0});
+                  auto bcast = innerBuilder.create<vector::BroadcastOp>(
+                      loc, VectorType::get({N}, elem.getType()), elem);
+                  broadcasts.push_back(bcast);
+                }
+
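+                // Likewise re-materialize the RHS subview (reduction iv at
+                // operand 1, k iv at operand 2) and load its 1x1xN row as a
+                // flat <N> vector.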
+                IRMapping rhsMapping;
+                rhsMapping.map(
+                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
+                    iv);
+                rhsMapping.map(
+                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(2),
+                    innerIv);
+                auto rhsClone = innerBuilder.clone(
+                    *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
+                auto rowVec = innerBuilder.create<vector::LoadOp>(
+                    loc, VectorType::get({N}, elementType),
+                    rhsClone->getResult(0), ValueRange{c0, c0, c0});
+
+                // Create M different FMAs using broadcasts and current
+                // accumulator values.
+                for (int i = 0; i < M; i++) {
+                  auto fma = innerBuilder.create<vector::FMAOp>(
+                      loc, broadcasts[i], rowVec, innerIterArgs[i]);
+                  results.push_back(fma);
+                }
+
+                // Yield all M results
+                innerBuilder.create<scf::YieldOp>(loc, results);
+              });
+
+          // Yield results from inner loop to outer loop
+          nestedBuilder.create<scf::YieldOp>(loc, newInnerForOp.getResults());
+        });
+
+    Value matResult = ctx.outerForOp.getResult(0);
+    Operation *writeOp = nullptr;
+    for (auto user : matResult.getUsers()) {
+      writeOp = dyn_cast<vector::TransferWriteOp>(user);
+      if (writeOp)
+        break;
+    }
+
+    // Store final results back to original locations.
+    if (writeOp) {
+      for (int i = 0; i < M; i++) {
+        rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
+                                         subview_2_splits[i],
+                                         ValueRange{c0, c0});
+      }
+    }
+
+    // Erase original write.
+    if (writeOp)
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+
+};
+
+void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
+  patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
new file mode 100644
index 00000000000000..f1f24e4e53cb66
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index
+    %c64 = arith.constant 64 : index
+    %c4 = arith.constant 4 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+    scf.forall (%arg1, %arg2) in (8, 24) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c64 step %c64 {
+          %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c24 step %c1 {
+            scf.for %arg6 = %c0 to %c64 step %c1 {
+              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+              %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+              %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+              %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+              vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    return %alloc : memref<8x24x32x64xf32>
+  }
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL:   func.func @simple_gemm(
+// CHECK-SAME:                     %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 24 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK:           scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 24) {
+// CHECK:             %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_5]] {
+// CHECK:                 %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+// CHECK:                 %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x64xf32>) {
+// CHECK:                   %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x64xf32>) {
+// CHECK:                     %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+// CHECK:                     %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+// CHECK:                     %[[VAL_29:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+// CHECK:                     scf.yield %[[VAL_29]] : vector<4x64xf32>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_22]] : vector<4x64xf32>
+// CHECK:                 }
+// CHECK:                 vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[VAL_10]] : memref<8x24x32x64xf32>
+// CHECK:         }
+
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
new file mode 100644
index 00000000000000..ba11074e3c9637
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index
+    %c64 = arith.constant 64 : index
+    %c4 = arith.constant 4 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+    scf.forall (%arg1, %arg2) in (8, 24) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c32 step %c4 {
+        scf.for %arg4 = %c0 to %c64 step %c64 {
+          %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+          %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+          %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
+            %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
+              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+              %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+              %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+              %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+              %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+              scf.yield %6 : vector<4x64xf32>
+            }
+            scf.yield %3 : vector<4x64xf32>
+          }
+          vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+        }
+      }
+    }
+    return %alloc : memref<8x24x32x64xf32>
+  }
+
+// CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL:   func.func @simple_gemm(
+// CHECK-SAME:                           %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 24 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 64 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK:           scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
+// CHECK:             %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:             %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
+// CHECK:               scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
+// CHECK:                 %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
+// CHECK:                 %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK:                   %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
+// CHECK:                     %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
+// CHECK:                     %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK:                     %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                     %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
+// CHECK:                     %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
+// CHECK:                     %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
+// CHECK:                     %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
+// CHECK:                     scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
+// CHECK:                 }
+// CHECK:                 vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:                 vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[VAL_11]] : memref<8x24x32x64xf32>
+// CHECK:         }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      transform.apply_patterns to %0 {
+        transform.apply_patterns.vector.contract_to_fma
+      } : !transform.any_op
+      transform.yield
+    }
+  }

>From 83ed5c4db59730cc0b3bd38a0e5c2551904992b1 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 02:21:34 -0800
Subject: [PATCH 2/3] added few more test-cases and code re-factoring

---
 .../Linalg/TransformOps/LinalgTransformOps.td |   8 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   9 +-
 .../Transforms/HoistVectorTransfers.cpp       |  49 ++++++--
 .../Linalg/Transforms/VectorContractToFMA.cpp |  56 ++++++++-
 .../Dialect/Linalg/hoist-vector-transfer.mlir |  93 ++++++++++++++-
 .../Linalg/vector-contract-to-fma.mlir        | 107 ++++--------------
 6 files changed, 216 insertions(+), 106 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d614bb4789767b..9acaf1ba231a0d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -111,7 +111,8 @@ def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.hoist_vector_transfer",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Hoists the vector transfer reads/writes outside the reduction and k-loop.
+    Finds patterns to hoist vector transfer reads/writes outside the reduction
+    and k-loops for a batch reduce matmul operation.
   }];
 
   let assemblyFormat = "attr-dict";
@@ -122,9 +123,8 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.contract_to_fma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Implements the lowering of a vector contraction op for a GEMM of size MxN
-    to a sequence of vector FMAs wrapped inside an scf.for loop with iter_args
-    accumulating the FMA results.
+    Collects patterns to lower a vector contraction operation (batch reduce
+    matmul) for a GEMM of size MxN to a sequence of vector FMAs.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d35f99826004ed..6f639b45408d87 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1825,13 +1825,12 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
 
 
-/// Pattern to hoist the vector transfer reads/writes outside the reduction and
-/// k-loop.
+/// Pattern to hoist the vector transfer reads/writes outside the reduction and
+/// k-loop for a batch reduce matmul operation if LICM fails.
 void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
 
-
-/// Pattern to lower vector contraction op for GEMM of size MxN to
-/// sequence of vector FMAs
+/// Pattern to lower a vector contraction operation (batch reduce matmul) for a
+/// GEMM of size MxN to a sequence of vector FMAs.
 void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
index 5911d36cbafef6..1e741010c741ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
@@ -1,15 +1,11 @@
-//===-HoistVectorTransfers.cpp -----------------------------------------*-
-// C++-*-===//
+//===- HoistVectorTransfers.cpp ---------------------------------------*- C++-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file implements tile configuration hoisting on parallel loops.
-//
-//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -25,6 +21,7 @@
 
 using namespace mlir;
 
+// Retrieves the vector transfer read ops (Acc, Lhs, and Rhs) of the contract op.
 static FailureOr<SmallVector<vector::TransferReadOp>>
 getContractOperands(vector::ContractionOp contractOp) {
   SmallVector<vector::TransferReadOp> list;
@@ -38,6 +35,7 @@ getContractOperands(vector::ContractionOp contractOp) {
   return list;
 }
 
+// Retrieves the subview source of each vector transfer read operation.
 static FailureOr<SmallVector<memref::SubViewOp>>
 getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
   SmallVector<memref::SubViewOp> list;
@@ -50,6 +48,7 @@ getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
   return list;
 }
 
+// Retrieves the tiled loop nest (m->n->reduction->k) around the contract op.
 static FailureOr<SmallVector<scf::ForOp>>
 getNestedLoop(vector::ContractionOp contractOp) {
   SmallVector<scf::ForOp> list;
@@ -64,6 +63,7 @@ getNestedLoop(vector::ContractionOp contractOp) {
   return list;
 }
 
+// Checks that the nested loop IVs match the subview offsets.
 static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
                                      SmallVector<memref::SubViewOp> subviews) {
   auto subviewOpLhsOffsets = subviews[0].getOffsets();
@@ -90,6 +90,40 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
   return success();
 }
 
+/// Hoists the vector transfer read and write operations of a tiled batch
+/// reduce matmul outside the reduction and k-loops.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// scf.for %arg3 = %c0 to %c32 step %c4  // m-loop
+///  scf.for %arg4 = %c0 to %c64 step %c64   // n-loop
+///   %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+///    scf.for %arg5 = %c0 to %c24 step %c1   // reduction-loop
+///     scf.for %arg6 = %c0 to %c64 step %c1   // k-loop
+///      %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] 
+///      %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] 
+///      %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///      %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///      %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} 
+///      %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 
+///      vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// scf.for %arg3 = %c0 to %c32 step %c4 
+///  scf.for %arg4 = %c0 to %c64 step %c64 
+///   %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+///   %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} 
+///   %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+///    %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+///     %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] 
+///     %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] 
+///     %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///     %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///     %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 
+///     scf.yield %6 : !type
+///    }
+///    scf.yield %3 : !type
+///   }
+///   vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} 
+///
 struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
 
@@ -98,7 +132,6 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
 
     // Check the vector contract operation satisfies the required pattern.
     // Check the Acc, Lhs, and Rhs of contract operation
-
     auto operands = getContractOperands(contractOp);
     if (failed(operands))
       return rewriter.notifyMatchFailure(contractOp,
@@ -145,7 +178,7 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
     if (K != 1)
       return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
 
-    // Check whether the linalg tiling + vector contract pattern matches for the
+    // Check whether the batch-reduce matmul tiling + vector contract pattern matches the
     // 4-nested loop structure
     auto loops = getNestedLoop(contractOp);
     if (failed(loops))
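
For readers who want to exercise the hoisting patterns outside the transform dialect, a
minimal standalone C++ driver sketch follows. It assumes only the
populateHoistVectorTransferPatterns entry point added by this patch plus MLIR's classic
greedy rewrite driver; the wrapper function name is hypothetical and not part of the PR.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the hoisting pattern and apply it greedily until fixpoint on `root`
// (typically a func.func). Returns failure if the driver did not converge.
static LogicalResult hoistVectorTransfersOn(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateHoistVectorTransferPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}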
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
index 2a8132b93bdcb0..4d3dac6a2b4d03 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
@@ -1,4 +1,3 @@
-
 //===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -67,6 +66,61 @@ struct TransformationContext {
 
 enum class MatMulType { Standard, Batch, BatchReduce };
 
+
+/// Pattern to lower a vector contraction (batch-reduce matmul) computing a GEMM
+/// tile of size MxN into a sequence of vector FMAs.
+///
+/// As an example, the following pseudo-code will be rewritten
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]} 
+/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+///  %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+///   %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+///   %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] 
+///   %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///   %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
+///   %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+///   scf.yield %6 : !type
+///  }
+///  scf.yield %3 : !type
+/// }
+/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
+/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1] 
+/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1] 
+/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1] 
+/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1] 
+/// %1 = vector.load %subview_2[%c0, %c0] 
+/// %2 = vector.load %subview_3[%c0, %c0] 
+/// %3 = vector.load %subview_4[%c0, %c0] 
+/// %4 = vector.load %subview_5[%c0, %c0] 
+/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
+///   %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
+///     %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1] 
+///     %7 = memref.load %subview_6[%c0, %c0, %c0] 
+///     %8 = vector.broadcast %7 : f32 to !type
+///     %9 = memref.load %subview_6[%c0, %c1, %c0] 
+///     %10 = vector.broadcast %9 : f32 to !type
+///     %11 = memref.load %subview_6[%c0, %c2, %c0] 
+///     %12 = vector.broadcast %11 : f32 to !type
+///     %13 = memref.load %subview_6[%c0, %c3, %c0] 
+///     %14 = vector.broadcast %13 : f32 to !type
+///     %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1] 
+///     %15 = vector.load %subview_7[%c0, %c0, %c0] 
+///     %16 = vector.fma %8, %15, %arg11 : !type
+///     %17 = vector.fma %10, %15, %arg12 : !type
+///     %18 = vector.fma %12, %15, %arg13 : !type
+///     %19 = vector.fma %14, %15, %arg14 : !type
+///     scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
+///   }
+///   scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
+/// }
+/// vector.store %5#0, %subview_2[%c0, %c0] 
+/// vector.store %5#1, %subview_3[%c0, %c0] 
+/// vector.store %5#2, %subview_4[%c0, %c0] 
+/// vector.store %5#3, %subview_5[%c0, %c0]
+///
 struct VectorContractToFMA
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
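
To make the lowering above concrete, here is a small scalar model (plain C++, not part
of the patch) of one k-step for the M = 4, N = 64, K = 1 tile: one element of the LHS
column is broadcast per accumulator row and fused with the shared RHS row, matching one
vector.broadcast + vector.fma pair per row in the generated IR.

#include <array>

constexpr int M = 4;  // rows of the accumulator tile
constexpr int N = 64; // vector length of one row

// One k-step of the lowered kernel: aCol holds A[r, 0..M-1, k] (K = 1), bRow holds the
// single RHS row B[r, k, 0..N-1] shared by all FMAs, acc is the MxN accumulator tile.
void fmaTileStep(const std::array<float, M> &aCol,
                 const std::array<float, N> &bRow,
                 std::array<std::array<float, N>, M> &acc) {
  for (int m = 0; m < M; ++m) {
    float a = aCol[m];            // memref.load + vector.broadcast
    for (int n = 0; n < N; ++n)
      acc[m][n] += a * bRow[n];   // vector.fma over the full row
  }
}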
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
index f1f24e4e53cb66..3b57f159108ead 100644
--- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -4,7 +4,7 @@
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-  func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
     %c1 = arith.constant 1 : index
@@ -46,7 +46,7 @@
 
 // CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
 
-// CHECK-LABEL:   func.func @simple_gemm(
+// CHECK-LABEL:   func.func @tiled_gemm_hoist_vector_transfer_operations(
 // CHECK-SAME:                     %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
@@ -95,3 +95,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+    scf.forall (%arg1, %arg2) in (8, 24) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+      %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+      %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+      %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+      %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+      vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    }
+    return %alloc : memref<8x24x32x64xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting
+// CHECK: memref.subview
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module {
+  func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+    %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+    %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+    return %4 : tensor<4x64xf32>
+  }
+}
+
+
+// CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
index ba11074e3c9637..f18c0dcb573d70 100644
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
@@ -1,106 +1,41 @@
 // RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-  func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+
+  func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
     %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
     %c1 = arith.constant 1 : index
-    %c24 = arith.constant 24 : index
+    %c16 = arith.constant 16 : index
     %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
     %c4 = arith.constant 4 : index
     %c32 = arith.constant 32 : index
     %c0 = arith.constant 0 : index
-    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
-    scf.forall (%arg1, %arg2) in (8, 24) {
-      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
-      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
-      scf.for %arg3 = %c0 to %c32 step %c4 {
-        scf.for %arg4 = %c0 to %c64 step %c64 {
-          %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-          %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
-          %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
-            %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
-              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-              %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
-              %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
-              %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
-              %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
-              scf.yield %6 : vector<4x64xf32>
-            }
+
+    scf.for %arg5 = %c0 to %c32 step %c4 {
+      scf.for %arg6 = %c0 to %c128 step %c64 {
+        %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+        %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
+          %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
+            %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
+            %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
+            %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
+            %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
+            %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
             scf.yield %3 : vector<4x64xf32>
           }
-          vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+          scf.yield %con1 : vector<4x64xf32>
         }
+        vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
       }
     }
-    return %alloc : memref<8x24x32x64xf32>
+    return
   }
 
-// CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-
-// CHECK-LABEL:   func.func @simple_gemm(
-// CHECK-SAME:                           %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 24 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 64 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant 4 : index
-// CHECK:           %[[VAL_8:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
-// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
-// CHECK:           scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
-// CHECK:             %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:             vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:             %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
-// CHECK:               scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
-// CHECK:                 %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:                 %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:                 %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:                 %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:                 %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
-// CHECK:                 %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
-// CHECK:                   %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
-// CHECK:                     %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
-// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
-// CHECK:                     %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
-// CHECK:                     %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
-// CHECK:                     %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
-// CHECK:                     %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                     %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
-// CHECK:                     %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
-// CHECK:                     %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
-// CHECK:                     %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
-// CHECK:                     scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
-// CHECK:                   }
-// CHECK:                   scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
-// CHECK:                 }
-// CHECK:                 vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:                 vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
-// CHECK:               }
-// CHECK:             }
-// CHECK:           }
-// CHECK:           return %[[VAL_11]] : memref<8x24x32x64xf32>
-// CHECK:         }
+// CHECK-NOT: vector.fma
 
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {

>From 8aa56e403e3a30eaae5d3deba6a73072b6f60df7 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 6 Jan 2025 21:21:38 -0800
Subject: [PATCH 3/3] Created a separate PR for the vector.contract-to-FMA lowering

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  13 -
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 -
 .../TransformOps/LinalgTransformOps.cpp       |   8 -
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 -
 .../Linalg/Transforms/VectorContractToFMA.cpp | 410 ------------------
 .../Dialect/Linalg/hoist-vector-transfer.mlir | 171 ++++----
 .../Linalg/vector-contract-to-fma.mlir        |  48 --
 7 files changed, 78 insertions(+), 578 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
 delete mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12b3ddb9c74904..6b890272bb6b49 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,7 +106,6 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-
 def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.hoist_vector_transfer",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -118,18 +117,6 @@ def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-
-def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.contract_to_fma",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-     Collects pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
-        sequence of vector FMAs.
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.linalg.pad_vectorization",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6f639b45408d87..8a06df4fed3632 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,15 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
 /// suffices for achieving the sum.
 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
 
-
 /// Pattern to hoists the vector transfer reads/writes outside the reduction and
 /// k-loop for batch reduce matmul operation if licm fails.
 void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
 
-/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
-/// sequence of vector FMAs.
-void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
-
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ed838f8476486f..61a3db7302d8db 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,19 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
   linalg::populateFoldAddIntoDestPatterns(patterns);
 }
 
-
 void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populateHoistVectorTransferPatterns(patterns);
 }
 
-
-void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  linalg::populateVectorContractToFMAPatterns(patterns);
-}
-
-
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 90d926201cd753..63758a654f803f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -42,7 +42,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Vectorization.cpp
   WinogradConv2D.cpp
   HoistVectorTransfers.cpp
-  VectorContractToFMA.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
deleted file mode 100644
index 4d3dac6a2b4d03..00000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp
+++ /dev/null
@@ -1,410 +0,0 @@
-//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements lowering of vector contraction to vector fma.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "vector-contract-to-fma"
-
-using namespace mlir;
-
-/// Returns true if the \p map is transposed.
-static bool isTransposed(AffineMap map) {
-  auto results = map.getResults();
-  // Assert if the map does not have 3 or 4 inputs ([] m, n, k).
-  assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) &&
-         "3 or 4 input dim expected");
-  // Assert if the result is not 2D.
-  assert(map.getNumResults() == 2 && "Only 2 output dim expected");
-
-  // Check the last two dimensions for transposition.
-  auto dimExpr0 = dyn_cast<AffineDimExpr>(results[0]);
-  auto dimExpr1 = dyn_cast<AffineDimExpr>(results[1]);
-  assert((dimExpr0 && dimExpr1) && "Unexpected dim expression");
-
-  // Exclude output map result.
-  bool isOutputResultMap =
-      dimExpr0 ==
-          mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) &&
-      dimExpr1 ==
-          mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext());
-  assert(!isOutputResultMap && "Output result map not expected");
-
-  // It's transposed if result found as (k, m) or (n, k), else not transposed.
-  if ((dimExpr0 ==
-           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) &&
-       dimExpr1 ==
-           mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) ||
-      (dimExpr0 ==
-           mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) &&
-       dimExpr1 ==
-           mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext())))
-    return true;
-  return false;
-}
-
-
-// Structure to hold transformation context
-struct TransformationContext {
-  scf::ForOp innerForOp;
-  scf::ForOp outerForOp;
-  scf::ForOp outermostLoop;
-};
-
-enum class MatMulType { Standard, Batch, BatchReduce };
-
-
-/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to
-/// sequence of vector FMAs.
-///
-/// As an example, the following pseudo-code will be rewritten
-/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
-/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]} 
-/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
-///  %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
-///   %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
-///   %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] 
-///   %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
-///   %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} 
-///   %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
-///   scf.yield %6 : !type
-///  }
-///  scf.yield %3 : !type
-/// }
-/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]}
-/// to:
-/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] 
-/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1] 
-/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1] 
-/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1] 
-/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1] 
-/// %1 = vector.load %subview_2[%c0, %c0] 
-/// %2 = vector.load %subview_3[%c0, %c0] 
-/// %3 = vector.load %subview_4[%c0, %c0] 
-/// %4 = vector.load %subview_5[%c0, %c0] 
-/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) {
-///   %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) {
-///     %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1] 
-///     %7 = memref.load %subview_6[%c0, %c0, %c0] 
-///     %8 = vector.broadcast %7 : f32 to !type
-///     %9 = memref.load %subview_6[%c0, %c1, %c0] 
-///     %10 = vector.broadcast %9 : f32 to !type
-///     %11 = memref.load %subview_6[%c0, %c2, %c0] 
-///     %12 = vector.broadcast %11 : f32 to !type
-///     %13 = memref.load %subview_6[%c0, %c3, %c0] 
-///     %14 = vector.broadcast %13 : f32 to !type
-///     %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1] 
-///     %15 = vector.load %subview_7[%c0, %c0, %c0] 
-///     %16 = vector.fma %8, %15, %arg11 : !type
-///     %17 = vector.fma %10, %15, %arg12 : !type
-///     %18 = vector.fma %12, %15, %arg13 : !type
-///     %19 = vector.fma %14, %15, %arg14 : !type
-///     scf.yield %16, %17, %18, %19 : !type, !type, !type, !type
-///   }
-///   scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type
-/// }
-/// vector.store %5#0, %subview_2[%c0, %c0] 
-/// vector.store %5#1, %subview_3[%c0, %c0] 
-/// vector.store %5#2, %subview_4[%c0, %c0] 
-/// vector.store %5#3, %subview_5[%c0, %c0])
-///
-struct VectorContractToFMA
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getKind() != vector::CombiningKind::ADD)
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported combining kind, only supports ADD at the moment)");
-
-    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-    if (maskableOp.isMasked())
-      return rewriter.notifyMatchFailure(op, "Masked contractOp not supported");
-
-    SmallVector<AffineMap, 3> maps = op.getIndexingMapsArray();
-    if (llvm::any_of(
-            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
-      return rewriter.notifyMatchFailure(op, "Unexpected map");
-
-    // Check for the variant of matrix multiply.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    MatMulType matmulType;
-    unsigned outerDimIndex = 0;
-    if (iteratorTypes.size() > 3) {
-      outerDimIndex = iteratorTypes.size() - 4;
-      matmulType =
-          iteratorTypes[outerDimIndex] == vector::IteratorType::parallel
-              ? MatMulType::Batch
-              : MatMulType::BatchReduce;
-      outerDimIndex++;
-    } else if (iteratorTypes.size() == 3) {
-      matmulType = MatMulType::Standard;
-    } else {
-      return rewriter.notifyMatchFailure(op, "Not a gemm");
-    }
-
-    if (matmulType == MatMulType::Batch)
-      return rewriter.notifyMatchFailure(op, "Batch matmul not supported");
-    if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel ||
-        iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel ||
-        iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
-      return rewriter.notifyMatchFailure(op, "Not a gemm");
-
-    SmallVector<Value, 4> results;
-
-    auto lhs = op.getLhs();
-    auto rhs = op.getRhs();
-    auto acc = op.getAcc();
-    auto lhsDefiningOp = lhs.getDefiningOp<vector::TransferReadOp>();
-    auto rhsDefiningOp = rhs.getDefiningOp<vector::TransferReadOp>();
-    auto accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
-    if (!lhsDefiningOp || !rhsDefiningOp)
-      return failure();
-
-    // Accumulator can be a TransferReadOp but must be coming from the chain of
-    // iterargs of nested loop.
-    if (accDefiningOp)
-      return failure();
-
-    // Make sure the inputs being read are whole tensor or subview.
-    if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
-        !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
-      return failure();
-    }
-
-    auto lhsType = cast<ShapedType>(lhsDefiningOp.getType());
-    auto rhsType = cast<ShapedType>(rhsDefiningOp.getType());
-    // auto accType = acc.getType();
-    //  auto accType = cast<ShapedType>(accDefiningOp.getType());
-
-    if (matmulType == MatMulType::BatchReduce &&
-        (lhsType.getRank() != 3 || rhsType.getRank() != 3))
-      return failure();
-
-    if (matmulType == MatMulType::Standard &&
-        (lhsType.getRank() != 2 || rhsType.getRank() != 2))
-      return failure();
-
-    // Check for non-transposed matrices.
-    auto mapLHS = maps[0];
-    auto mapRHS = maps[1];
-    if (matmulType == MatMulType::BatchReduce) {
-      mapLHS = mapLHS.dropResult(0);
-      mapRHS = mapRHS.dropResult(0);
-    }
-    if (isTransposed(mapLHS) || isTransposed(mapRHS))
-      return rewriter.notifyMatchFailure(
-          op, "Transposed matrices are not expected");
-
-    // Verify that the accumulator is coming through a chain of iterargs of
-    // nested loop and it is define by 'TransferReadOp'.
-    //
-    struct TransformationContext ctx;
-
-    ctx.innerForOp = op->getParentOfType<scf::ForOp>();
-    if (!ctx.innerForOp)
-      return failure();
-    ctx.outerForOp = ctx.innerForOp->getParentOfType<scf::ForOp>();
-    if (!ctx.outerForOp)
-      return failure();
-    ctx.outermostLoop = ctx.outerForOp->getParentOfType<scf::ForOp>();
-    if (!ctx.outermostLoop)
-      return failure();
-
-    // Verify original inner loop has only one iterarg.
-    auto origIterArgs = ctx.innerForOp.getRegionIterArgs();
-    if (origIterArgs.size() != 1)
-      return failure();
-
-    // Verify chain, accumulator must be inner loop's iterarg.
-    auto bbArg = dyn_cast<BlockArgument>(acc);
-    if (!bbArg)
-      return failure();
-
-    // This block arg must be init arg, not induction variable.
-    if (bbArg.getOwner() != ctx.innerForOp.getBody() ||
-        bbArg.getArgNumber() == 0) {
-      return failure();
-    }
-
-    // This iterarg must be intialized by outer loop's iterarg.
-    auto innerInitValue =
-        ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1];
-    auto outerBBArg = dyn_cast<BlockArgument>(innerInitValue);
-    if (!outerBBArg)
-      return failure();
-
-    // This block arg must be init arg, not induction variable.
-    if (outerBBArg.getOwner() != ctx.outerForOp.getBody() ||
-        outerBBArg.getArgNumber() == 0) {
-      return failure();
-    }
-
-    // Outer loop's iterarg initializer must be a TransferReadOp.
-    acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1];
-
-    //  This must be defined by vector.transfer_read
-    if (!acc.getDefiningOp<vector::TransferReadOp>())
-      return failure();
-
-    accDefiningOp = acc.getDefiningOp<vector::TransferReadOp>();
-    if (!accDefiningOp)
-      return failure();
-
-    // Only 2-D output expected.
-    auto accType = cast<ShapedType>(accDefiningOp.getType());
-    if (accType.getRank() != 2)
-      return failure();
-
-    int64_t M = accType.getDimSize(0);
-    int64_t N = accType.getDimSize(1);
-    int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
-
-    // K must be 1.
-    if (K != 1)
-      return failure();
-
-    auto accSubview = accDefiningOp.getSource();
-    Location loc = op.getLoc();
-
-    // Create M different <1xN> subviews.
-    auto memrefType = cast<MemRefType>(accSubview.getType());
-    auto elementType = memrefType.getElementType();
-    SmallVector<OpFoldResult> mixedSizes = {rewriter.getIndexAttr(K),
-                                            rewriter.getIndexAttr(N)};
-    SmallVector<OpFoldResult> mixedStrides = {rewriter.getIndexAttr(1),
-                                              rewriter.getIndexAttr(1)};
-
-    rewriter.setInsertionPoint(
-        ctx.outermostLoop.getBody(),
-        std::prev(ctx.outermostLoop.getBody()->end(), 1));
-
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    SmallVector<Value, 4> subview_2_splits;
-    for (int i = 0; i < M; i++) {
-      SmallVector<OpFoldResult> mixedOffsets = {
-          rewriter.getIndexAttr(i),
-          rewriter.getIndexAttr(0),
-      };
-      auto split = rewriter.create<memref::SubViewOp>(
-          loc, accSubview, mixedOffsets, mixedSizes, mixedStrides);
-      subview_2_splits.push_back(split);
-    }
-
-    // Intialize each accumulator with a vector of size N
-    SmallVector<Value, 4> initAccs;
-    for (auto subview : subview_2_splits) {
-      auto acc = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0});
-      initAccs.push_back(acc);
-    }
-
-    // Create new outer loop with M different accumulators.
-    auto newOuterForOp = rewriter.create<scf::ForOp>(
-        loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(),
-        ctx.outerForOp.getStep(), initAccs,
-        [&](OpBuilder &nestedBuilder, Location loc, Value iv,
-            ValueRange iterArgs) {
-          // Create new inner loop with M accumulators.
-          auto newInnerForOp = nestedBuilder.create<scf::ForOp>(
-              loc, ctx.innerForOp.getLowerBound(),
-              ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(),
-              iterArgs,
-              [&](OpBuilder &innerBuilder, Location loc, Value innerIv,
-                  ValueRange innerIterArgs) {
-                IRMapping mapping;
-                mapping.map(
-                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
-                    iv);
-                mapping.map(
-                    lhsDefiningOp.getSource().getDefiningOp()->getOperand(3),
-                    innerIv);
-                auto lhsClone = innerBuilder.clone(
-                    *lhsDefiningOp.getSource().getDefiningOp(), mapping);
-
-                // Load and broadcast individual elements
-                SmallVector<Value, 4> broadcasts;
-                for (int i = 0; i < M; i++) {
-                  auto elem = innerBuilder.create<memref::LoadOp>(
-                      loc, lhsClone->getResult(0),
-                      ValueRange{
-                          c0,
-                          innerBuilder.create<arith::ConstantIndexOp>(loc, i),
-                          c0});
-                  auto bcast = innerBuilder.create<vector::BroadcastOp>(
-                      loc, VectorType::get({N}, elem.getType()), elem);
-                  broadcasts.push_back(bcast);
-                }
-
-                IRMapping rhsMapping;
-                rhsMapping.map(
-                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(1),
-                    iv);
-                rhsMapping.map(
-                    rhsDefiningOp.getSource().getDefiningOp()->getOperand(2),
-                    innerIv);
-                auto rhsClone = innerBuilder.clone(
-                    *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping);
-                auto rowVec = innerBuilder.create<vector::LoadOp>(
-                    loc, VectorType::get({N}, elementType),
-                    rhsClone->getResult(0), ValueRange{c0, c0, c0});
-
-                // Create M different FMAs using broadcasts and current
-                // accumulator values.
-                for (int i = 0; i < M; i++) {
-                  auto fma = innerBuilder.create<vector::FMAOp>(
-                      loc, broadcasts[i], rowVec, innerIterArgs[i]);
-                  results.push_back(fma);
-                }
-
-                // Yield all M results
-                innerBuilder.create<scf::YieldOp>(loc, results);
-              });
-
-          // Yield results from inner loop to outer loop
-          nestedBuilder.create<scf::YieldOp>(loc, newInnerForOp.getResults());
-        });
-
-    Value matResult = ctx.outerForOp.getResult(0);
-    Operation *writeOp;
-    for (auto user : matResult.getUsers()) {
-      writeOp = dyn_cast<vector::TransferWriteOp>(user);
-      if (writeOp)
-        break;
-    }
-
-    // Store final results back to original locations.
-    if (writeOp) {
-      for (int i = 0; i < M; i++) {
-        rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
-                                         subview_2_splits[i],
-                                         ValueRange{c0, c0});
-      }
-    }
-
-    // Erase original write.
-    if (writeOp)
-      rewriter.eraseOp(writeOp);
-
-    return success();
-  }
-
-};
-
-void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
-  patterns.add<VectorContractToFMA>(patterns.getContext());
-}
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
index 3b57f159108ead..b0b164951d4b32 100644
--- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -3,51 +3,48 @@
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-  func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
-    %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-    %c1 = arith.constant 1 : index
-    %c24 = arith.constant 24 : index
-    %c64 = arith.constant 64 : index
-    %c4 = arith.constant 4 : index
-    %c32 = arith.constant 32 : index
-    %c0 = arith.constant 0 : index
-    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
-    scf.forall (%arg1, %arg2) in (8, 24) {
-      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
-      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
-      scf.for %arg3 = %c0 to %c32 step %c4 {
-        scf.for %arg4 = %c0 to %c64 step %c64 {
-          %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-          scf.for %arg5 = %c0 to %c24 step %c1 {
-            scf.for %arg6 = %c0 to %c64 step %c1 {
-              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
-              %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
-              %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
-              %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
-              %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
-              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
-              vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
-            }
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %c1 = arith.constant 1 : index
+  %c24 = arith.constant 24 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    scf.for %arg3 = %c0 to %c32 step %c4 {
+      scf.for %arg4 = %c0 to %c64 step %c64 {
+        %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        scf.for %arg5 = %c0 to %c24 step %c1 {
+          scf.for %arg6 = %c0 to %c64 step %c1 {
+            %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+            %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+            %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+            %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+            %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+            %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+            vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
           }
         }
       }
     }
-    return %alloc : memref<8x24x32x64xf32>
   }
-
+  return %alloc : memref<8x24x32x64xf32>
+}
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
 // CHECK-LABEL:   memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-
 // CHECK-LABEL:   func.func @tiled_gemm_hoist_vector_transfer_operations(
-// CHECK-SAME:                     %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
 // CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
@@ -84,44 +81,39 @@
 // CHECK:           return %[[VAL_10]] : memref<8x24x32x64xf32>
 // CHECK:         }
 
-
-
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.hoist_vector_transfer
-    } : !transform.any_op
-    transform.yield
-  }
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
 }
 
-
 // -----
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-module {
-  memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
-  func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
-    %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-    %c0 = arith.constant 0 : index
-    %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
-    scf.forall (%arg1, %arg2) in (8, 24) {
-      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
-      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
-      %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
-      %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
-      %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
-      %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
-      vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
-    }
-    return %alloc : memref<8x24x32x64xf32>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %c0 = arith.constant 0 : index
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+    %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+    %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+    %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+    vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
   }
+  return %alloc : memref<8x24x32x64xf32>
 }
 
 // CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting
@@ -135,36 +127,31 @@ module {
 // CHECK-NEXT: vector.transfer_write
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.hoist_vector_transfer
-    } : !transform.any_op
-    transform.yield
-  }
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
 }
 
-
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-module {
-  func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
-    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
-    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
-    %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
-    %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
-    return %4 : tensor<4x64xf32>
-  }
+func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+  %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+  %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+  return %4 : tensor<4x64xf32>
 }
 
-
 // CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting
 // CHECK: vector.transfer_read
 // CHECK-NEXT: vector.transfer_read
@@ -174,13 +161,11 @@ module {
 // CHECK-NEXT: return
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.hoist_vector_transfer
-    } : !transform.any_op
-    transform.yield
-  }
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
 }
-
-
diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
deleted file mode 100644
index f18c0dcb573d70..00000000000000
--- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir
+++ /dev/null
@@ -1,48 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-
-  func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
-    %cst = arith.constant 0.000000e+00 : f32
-    %c1 = arith.constant 1 : index
-    %c16 = arith.constant 16 : index
-    %c64 = arith.constant 64 : index
-    %c128 = arith.constant 128 : index
-    %c4 = arith.constant 4 : index
-    %c32 = arith.constant 32 : index
-    %c0 = arith.constant 0 : index
-
-    scf.for %arg5 = %c0 to %c32 step %c4 {
-      scf.for %arg6 = %c0 to %c128 step %c64 {
-        %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
-        %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
-        %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
-          %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
-            %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
-            %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
-            %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
-            %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
-            %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
-            scf.yield %3 : vector<4x64xf32>
-          }
-          scf.yield %con1 : vector<4x64xf32>
-        }
-        vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
-      }
-    }
-    return
-  }
-
-// CHECK-NOT: vector.fma
-
-  module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-      transform.apply_patterns to %0 {
-        transform.apply_patterns.vector.contract_to_fma
-      } : !transform.any_op
-      transform.yield
-    }
-  }


