[Mlir-commits] [mlir] ee0ac74 - [mlir][x86] Lower packed type vector.contract to AMX dot-product (#182810)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 21:55:01 PDT 2026


Author: Arun Thangamani
Date: 2026-03-18T10:24:55+05:30
New Revision: ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48

URL: https://github.com/llvm/llvm-project/commit/ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48
DIFF: https://github.com/llvm/llvm-project/commit/ee0ac7443e4dc48f0ab2371dd5cbdcca32732e48.diff

LOG: [mlir][x86] Lower packed type vector.contract to AMX dot-product (#182810)

A transform pass to lower `vector.contract` operation to (a)
`amx.tile_mulf` for BF16, or (b) `amx.tile_muli` for Int8 packed types.

Added: 
    mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
    mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir

Modified: 
    mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
    mlir/include/mlir/Dialect/X86/Transforms.h
    mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
    mlir/lib/Dialect/X86/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
index e7508f9ba4abb..c474cfb47d003 100644
--- a/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
+++ b/mlir/include/mlir/Dialect/X86/TransformOps/X86TransformOps.td
@@ -71,5 +71,16 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorContractToAMXDotProductPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86.vector_contract_to_amx_dot_product",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to lower a BF16/Int8 type vector.contract operation
+    to a BF16/Int8 AMX tiled dot-product.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h
index 2862e83f06f79..6ebba5e94ec7c 100644
--- a/mlir/include/mlir/Dialect/X86/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86/Transforms.h
@@ -104,6 +104,12 @@ void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 // grouped with respect to odd/even packed index.
 void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
 
+// Populates `patterns` with rewrites that lower 32-bit packed vector
+// contraction operations (e.g., BF16 and Int8) to their corresponding
+// packed-type AMX tiled dot-product operations, which ultimately target the
+// relevant x86 LLVM intrinsics.
+void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h

diff  --git a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
index 2511f6f1b8b4c..390b21e12b0ed 100644
--- a/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
+++ b/mlir/lib/Dialect/X86/TransformOps/X86TransformOps.cpp
@@ -47,6 +47,11 @@ void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
   x86::populateShuffleVectorFMAOpsPatterns(patterns);
 }
 
+// Contributes the vector.contract to AMX dot-product lowering patterns by
+// forwarding to x86::populateVectorContractToAMXDotProductPatterns.
+void mlir::transform::ApplyVectorContractToAMXDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86::populateVectorContractToAMXDotProductPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
index efb6fde4fade1..9c3695536cda9 100644
--- a/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRX86Transforms
   VectorContractBF16ToFMA.cpp
   SinkVectorProducerOps.cpp
   ShuffleVectorFMAOps.cpp
+  VectorContractToAMXDotProduct.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect

diff  --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
new file mode 100644
index 0000000000000..85966a85af40e
--- /dev/null
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -0,0 +1,675 @@
+//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86/Transforms.h"
+#include "mlir/Dialect/X86/Utils/X86Utils.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86;
+
+namespace {
+
+// Collapses the two innermost dimensions of `input` (the k and vnni
+// dimensions of a packed operand) into a single dimension, to help
+// amx.tile_load correctly load the packed element type.
+//
+// Inputs with fewer than two dimensions have nothing to collapse and are
+// returned unchanged. Otherwise a memref.collapse_shape op is created at the
+// builder's current insertion point and its result is returned.
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+                               Value input) {
+  ShapedType inputType = cast<ShapedType>(input.getType());
+  int64_t rank = inputType.getRank();
+
+  // Guard before computing the collapse start: for a rank below 2 it would
+  // be negative (rank 0) or describe a no-op (rank 1).
+  if (rank < 2)
+    return input;
+
+  int64_t firstDimToCollapse = rank - 2;
+
+  // Every leading dimension stays in its own reassociation group.
+  SmallVector<ReassociationIndices> reassociation;
+  reassociation.reserve(firstDimToCollapse + 1);
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+
+  // The last two dimensions are merged into one group.
+  reassociation.push_back(
+      ReassociationIndices{firstDimToCollapse, firstDimToCollapse + 1});
+  return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
+}
+
+// Returns the source buffer and the index values of the vector read
+// (vector.transfer_read or vector.load) that defines `operand`.
+//
+// When `isNotAcc` is true (used for the LHS/RHS operands), the innermost
+// index is dropped and the two innermost dimensions of the buffer are
+// collapsed via collapseInnerDims, which may create IR at the current
+// insertion point of `rewriter`.
+//
+// Fails if `operand` is not defined by one of the supported reads, or if
+// `isNotAcc` is set but the read has no index to drop.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+                bool isNotAcc) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp)
+      .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+        srcBuff = readOp.getOperand(0);
+      });
+
+  // `srcBuff` stays null when the defining op is not a supported read.
+  if (!srcBuff)
+    return failure();
+
+  if (isNotAcc) {
+    // Dropping the innermost index requires at least one index to exist.
+    if (indexVals.empty())
+      return failure();
+    indexVals.pop_back();
+  }
+
+  SmallVector<Value> indices;
+  indices.reserve(indexVals.size());
+  for (OpFoldResult ofr : indexVals)
+    indices.push_back(
+        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+
+  if (isNotAcc)
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+
+  return std::make_pair(srcBuff, indices);
+}
+
+// Checks whether `contractOp` has the operand shapes this lowering supports.
+//
+// When `srcValidate` is true, the buffers read by the LHS and RHS operands
+// must additionally be exactly `srcBuffLhs` and `srcBuffRhs`.
+//
+// Shape requirements:
+//   - accumulator: a vector whose dims are all 16 or 1, like <1x16x16> or
+//     <16x16>;
+//   - LHS and RHS: dims are 16 or 1 except for exactly one dim (the vnni
+//     dim), which must equal `blockingFactor`, like <1x16x16x2> or <16x16x4>.
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+                                         vector::ContractionOp contractOp,
+                                         unsigned int blockingFactor,
+                                         Value srcBuffLhs, Value srcBuffRhs,
+                                         bool srcValidate) {
+  if (srcValidate) {
+    Location loc = contractOp.getLoc();
+
+    // Resolve the MemRef buffer read by the LHS operand.
+    auto lhsSrc = getSrcIndxValue(rewriter, loc, contractOp.getLhs(), false);
+    if (failed(lhsSrc))
+      return failure();
+
+    // Resolve the MemRef buffer read by the RHS operand.
+    auto rhsSrc = getSrcIndxValue(rewriter, loc, contractOp.getRhs(), false);
+    if (failed(rhsSrc))
+      return failure();
+
+    // Both buffers have to match the expected sources.
+    if (lhsSrc->first != srcBuffLhs || rhsSrc->first != srcBuffRhs)
+      return failure();
+  }
+
+  auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+  if (!accTy)
+    return failure();
+
+  // Collects the dims of `shape` that are neither 16 nor 1.
+  auto collectOtherDims = [](ArrayRef<int64_t> shape) {
+    SmallVector<int64_t> otherDims;
+    for (int64_t dim : shape)
+      if (dim != 16 && dim != 1)
+        otherDims.push_back(dim);
+    return otherDims;
+  };
+
+  // The accumulator may only have dims of size 16 or 1.
+  if (!collectOtherDims(accTy.getShape()).empty())
+    return failure();
+
+  // An input operand must have exactly one remaining dim (the vnni dim), and
+  // it must equal the blocking factor.
+  auto hasOnlyVnniDim = [&](VectorType operandTy) {
+    SmallVector<int64_t> otherDims = collectOtherDims(operandTy.getShape());
+    return otherDims.size() == 1 && otherDims[0] == blockingFactor;
+  };
+
+  if (!hasOnlyVnniDim(contractOp.getLhsType()))
+    return failure();
+
+  if (!hasOnlyVnniDim(contractOp.getRhsType()))
+    return failure();
+
+  return success();
+}
+
+// Returns the position, within the offsets returned by `getOffsets()` of the
+// memref.subview feeding the read that defines `operand`, of the offset that
+// is the induction variable of `loop`. This position is used to remap the
+// induction variable when the subview is cloned.
+//
+// Returns 0 when `operand` is not defined by a vector.transfer_read or
+// vector.load, when the read's source is not a memref.subview, or when no
+// offset matches the induction variable.
+// NOTE(review): the fallback 0 is indistinguishable from a real match at
+// position 0 - confirm callers are fine with that.
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+  // A block argument has no defining op; TypeSwitch must not see a null op.
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return 0;
+
+  Value srcBuff;
+  llvm::TypeSwitch<Operation *>(defOp)
+      .Case<TransferReadOp, LoadOp>(
+          [&](auto readOp) { srcBuff = readOp.getOperand(0); });
+
+  // `srcBuff` stays null when the defining op is not a supported read.
+  if (!srcBuff)
+    return 0;
+
+  auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return 0;
+
+  Value iv = loop.getInductionVar();
+  for (auto it : llvm::enumerate(subview.getOffsets())) {
+    if (it.value() == iv)
+      return it.index();
+  }
+
+  return 0;
+}
+
+// Builds the amx.tile_load that replaces the vector read defining `operand`.
+//
+// The tile is loaded from `mat` using the indices of the original read with
+// the innermost index dropped. For the RHS operand (`rhs` is true) the new
+// innermost index is additionally multiplied by `offset`. The resulting tile
+// has shape 16 x (16 * offset) and element type `ipType`.
+//
+// The read lookup is dereferenced without a failure check, so `operand` is
+// expected to be defined by a vector.transfer_read or vector.load.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+                                       Value operand, Value mat, Type ipType,
+                                       bool rhs, unsigned int offset) {
+
+  auto readInfo = getSrcIndxValue(rewriter, loc, operand, false);
+  SmallVector<Value> tileIndices = readInfo->second;
+  // Drop the innermost index of the original read.
+  tileIndices.pop_back();
+
+  if (rhs) {
+    // Scale the innermost remaining index by the packing factor.
+    Value scale = arith::ConstantIndexOp::create(rewriter, loc, offset);
+    Value scaledIndex =
+        arith::MulIOp::create(rewriter, loc, tileIndices.back(), scale);
+    tileIndices.back() = scaledIndex;
+  }
+
+  amx::TileType loadedTileType =
+      amx::TileType::get({16, (16 * offset)}, ipType);
+  return amx::TileLoadOp::create(rewriter, loc, loadedTileType, mat,
+                                 tileIndices);
+}
+
+// Creates the tiled AMX dot-products for every contraction in `ops`.
+//
+// `matA` and `matB` are the LHS and RHS source buffers; their two innermost
+// dimensions are collapsed before tiles are loaded from them. For each
+// contraction, an amx.tile_mulf (BF16 input) or amx.tile_muli (signless Int8
+// input) is created, accumulating into the matching entry of `accIterArgs`.
+// `offset` is the packing factor forwarded to createTileLoads.
+//
+// Returns one accumulator value per contraction, in the order of `ops`. The
+// entry is a null Value if `ipType` is neither BF16 nor signless Int8.
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+                                        ArrayRef<vector::ContractionOp> ops,
+                                        Value matA, Value matB, Type ipType,
+                                        Type opType, ValueRange accIterArgs,
+                                        unsigned int offset) {
+
+  auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
+  auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+
+  // Maps each vector transfer_read or load operation to its equivalent
+  // amx.tile_load, so that a read shared by several contractions is only
+  // loaded once.
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+
+  // Returns the cached tile load for the read defining `operand`, creating
+  // (and caching) it on first use.
+  auto getOrCreateTileLoad = [&](Value operand, Value collapsedBuff,
+                                 bool isRhs) -> amx::TileLoadOp {
+    Operation *readOp = operand.getDefiningOp();
+    auto it = readsToTileLoads.find(readOp);
+    if (it != readsToTileLoads.end())
+      return it->second;
+
+    amx::TileLoadOp tileLoad = createTileLoads(
+        rewriter, loc, operand, collapsedBuff, ipType, isRhs, offset);
+    readsToTileLoads.try_emplace(readOp, tileLoad);
+    return tileLoad;
+  };
+
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> accumulators;
+  accumulators.reserve(ops.size());
+
+  // Iterate over the contraction operations and compute the tiled dot-product.
+  for (size_t i = 0; i < ops.size(); i++) {
+    vector::ContractionOp contractOp = ops[i];
+
+    amx::TileLoadOp tilesLhs =
+        getOrCreateTileLoad(contractOp.getLhs(), subviewCollapseLhs, false);
+    amx::TileLoadOp tilesRhs =
+        getOrCreateTileLoad(contractOp.getRhs(), subviewCollapseRhs, true);
+
+    Value dp;
+    if (ipType.isBF16())
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+    else if (ipType.isSignlessInteger(8))
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
+
+    accumulators.push_back(dp);
+  }
+  return accumulators;
+}
+
+// Creates `size` amx.tile_zero ops of type !amx.tile<16x16 x opType> right
+// before `outerLoop` and returns their results.
+//
+// Side effect: the insertion point of `rewriter` is moved to just before
+// `outerLoop` and is left there on return.
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+                                          Type opType, scf::ForOp outerLoop,
+                                          int64_t size) {
+  rewriter.setInsertionPoint(outerLoop);
+
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+
+  SmallVector<Value> zeroTiles;
+  for (int64_t remaining = size; remaining > 0; --remaining)
+    zeroTiles.push_back(amx::TileZeroOp::create(rewriter, loc, accTileType));
+
+  return zeroTiles;
+}
+
+// Implements tiled dot-product operation for a vector.contract operation or a
+// sequence of vector.contracts inside the reduction loops.
+//
+// For example - for Int8 type:
+// ```
+//   vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+//   vector.transfer_write %arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// to
+// ```
+//   amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+//   amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
+//
+// Two situations are handled:
+//   - Case 1: the accumulator read and the result write are in the same block
+//     as the contraction. The contraction is replaced by tile loads, a tile
+//     multiply and a tile store.
+//   - Case 2: the accumulator is carried as iter_args through one or two
+//     enclosing scf.for loops. New loops are created that accumulate into
+//     zero-initialized tiles; afterwards each tile is spilled to a 16x16
+//     buffer, added row by row to the original accumulator values and stored
+//     back. Only the original result writes are erased here; the original
+//     loops are left in place.
+struct VectorContractToAMXDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // VNNI blocking factor: 2 for BF16, 4 otherwise (Int8 is the only other
+    // element type accepted below).
+    unsigned int blockingFactor =
+        contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
+    bool isVnni =
+        isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(), blockingFactor);
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16() &&
+        !lhsTy.getElementType().isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+        (lhsTy.getElementType().isSignlessInteger(8) &&
+         !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 for BF16 or Int32 for Int8 "
+                                         "accumulation type is supported.");
+    if (!isVnni)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only VNNI-packed inputs are supported.");
+
+    Operation *accReadOp =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+    Operation *resultWriteOp =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+
+    if (!accReadOp || !resultWriteOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The ACC operand of the vector.contract should be a "
+                      "transfer_read or a load. And, the result should be "
+                      "stored using transfer_write or store.");
+
+    // Input (tile operand) and output (accumulator tile) element types.
+    Type ipType = rewriter.getBF16Type();
+    Type opType = rewriter.getF32Type();
+
+    if (lhsTy.getElementType().isSignlessInteger(8)) {
+      ipType = rewriter.getIntegerType(8);
+      opType = rewriter.getIntegerType(32);
+    }
+
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() != contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator store is in different block.");
+
+    if (accReadOp->getBlock() != contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator read is in different block.");
+
+    // Case 1: For just one VC rewrite. Where all accumulator read/write
+    // within the same block.
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock()) {
+
+      LogicalResult validate = validateContractOps(
+          rewriter, contractOp, blockingFactor, Value(), Value(), false);
+
+      if (failed(validate))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The contract operation doesn't satisfy the operands "
+                        "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
+                        "The rest dims should be 1.");
+
+      Location loc = contractOp.getLoc();
+
+      auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getLhs(), true);
+      if (failed(srcIndxLhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The LHS src is not a MemRef type.");
+      auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+      auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getRhs(), true);
+      if (failed(srcIndxRhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The RHS src is not a MemRef type.");
+      auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+      auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getAcc(), false);
+      if (failed(srcIndxAcc))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The ACC src is not a MemRef type.");
+      auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
+
+      // amx.tile_loads
+      auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
+      auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+                                             srcBuffLhs, indicesLhs);
+      auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
+                                             srcBuffRhs, indicesRhs);
+
+      auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
+      auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
+                                             srcBuffAcc, indicesAcc);
+
+      // Tiled dot-product.
+      Value dp;
+      if (ipType.isBF16())
+        dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+      else if (ipType.isSignlessInteger(8))
+        dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+
+      amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
+
+      rewriter.eraseOp(resultWriteOp);
+      return success();
+    }
+
+    // Case 2: The acc are passed as iter args through the reduction loop.
+    // We support, reduction loop depth until 2. TODO: Support for n-depth
+    // reduction loop.
+    SmallVector<scf::ForOp> loopLists;
+    Operation *current = contractOp;
+
+    // Walk up the enclosing scf.for loops (innermost first) until reaching
+    // the loop that lives in the same block as the accumulator read.
+    while (true) {
+      scf::ForOp parent = current->getParentOfType<scf::ForOp>();
+      // No enclosing scf.for is left: the accumulator read is not in the
+      // block of any enclosing loop, so there is nothing to rewrite.
+      if (!parent)
+        return rewriter.notifyMatchFailure(
+            contractOp, "The accumulator read is not in the block of an "
+                        "enclosing scf.for loop.");
+      loopLists.push_back(parent);
+
+      if (accReadOp->getBlock() == parent->getBlock()) {
+        break;
+      }
+
+      current = parent;
+    }
+
+    if (loopLists.size() > 2 || loopLists.size() == 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Rewrite is supported until reduction loop depth of 2.");
+
+    auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getLhs(), false);
+    if (failed(srcIndxLhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The LHS src is not a MemRef type.");
+    auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+    auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getRhs(), false);
+    if (failed(srcIndxRhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The RHS src is not a MemRef type.");
+    auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+    // Operations defining the buffers that the LHS/RHS reads load from. They
+    // are cloned into the new loops with the induction variables remapped.
+    Operation *vectorOpLhs = nullptr;
+    llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
+        .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+          vectorOpLhs = readOp.getBase().getDefiningOp();
+        });
+
+    Operation *vectorOpRhs = nullptr;
+    llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+        .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+          vectorOpRhs = readOp.getBase().getDefiningOp();
+        });
+
+    // The remapping below accesses operand `getIndexPosition(...) + 1` of
+    // these operations, so they must exist and have at least two operands.
+    if (!vectorOpLhs || !vectorOpRhs || vectorOpLhs->getNumOperands() < 2 ||
+        vectorOpRhs->getNumOperands() < 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The LHS/RHS MemRef source should be defined by an "
+                      "operation with offset operands.");
+
+    // Retrieve all the contraction operations within the loop.
+    SmallVector<vector::ContractionOp> ops;
+    for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
+
+      if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
+
+        LogicalResult validate = validateContractOps(
+            rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
+
+        if (failed(validate))
+          return rewriter.notifyMatchFailure(
+              contractOp, "The associated contract operations doesn't satisfy "
+                          "the re-write conditions either the dimensions are "
+                          "wrong or MemRef source are different.");
+
+        ops.push_back(contract);
+      }
+    }
+
+    scf::ForOp outerLoop;
+    scf::ForOp innerLoop;
+
+    scf::ForOp newLoop;
+    // Case 2a: Reduction loop depth is 2.
+    if (loopLists.size() == 2) {
+      outerLoop = loopLists[1];
+      innerLoop = loopLists[0];
+
+      SmallVector<Value> loopItrArgs = createTileZeros(
+          rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+
+      newLoop = scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+            auto newInnerLoop = scf::ForOp::create(
+                rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                innerLoop.getUpperBound(), innerLoop.getStep(),
+                iterArgsOuterLoop,
+                [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+                    Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+                  IRMapping mapping;
+                  mapping.map(
+                      vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                      ivOuterLoop);
+                  mapping.map(
+                      vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+                      ivNewInnerLoop);
+                  auto lhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+                  IRMapping rhsMapping;
+                  rhsMapping.map(
+                      vectorOpRhs->getOperand(
+                          getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+                      ivOuterLoop);
+                  rhsMapping.map(
+                      vectorOpRhs->getOperand(
+                          getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+                      ivNewInnerLoop);
+                  auto rhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+
+                  SmallVector<Value> accumulators = createTiledDp(
+                      rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
+                      rhsClone->getResult(0), ipType, opType,
+                      iterArgsNewInnerLoop, blockingFactor);
+
+                  scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+                                       accumulators);
+                });
+
+            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+                                 newInnerLoop.getResults());
+          });
+    }
+
+    // Case 2b: Reduction loop depth is 1.
+    if (loopLists.size() == 1) {
+      outerLoop = loopLists[0];
+
+      SmallVector<Value> loopItrArgs = createTileZeros(
+          rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
+      newLoop = scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+            IRMapping mapping;
+            mapping.map(
+                vectorOpLhs->getOperand(
+                    getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                ivOuterLoop);
+
+            auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
+
+            IRMapping rhsMapping;
+            rhsMapping.map(
+                vectorOpRhs->getOperand(
+                    getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+                ivOuterLoop);
+
+            auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
+
+            SmallVector<Value> accumulators = createTiledDp(
+                rewriter, locOuterLoop, ops, lhsClone->getResult(0),
+                rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
+                blockingFactor);
+
+            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
+          });
+    }
+
+    // post processing after the loop creation.
+    // Copy the amx tile accumulation results to a MemRef buffer, add the
+    // initial accumulation value, and store back to the C-Matrix
+    auto bufferType = MemRefType::get({16, 16}, opType);
+    auto bBuffer =
+        memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+    SmallVector<Value> dps = newLoop.getResults();
+    for (size_t i = 0; i < ops.size(); i++) {
+      vector::ContractionOp contOp = ops[i];
+      Operation *resultWriteOp =
+          traceToVectorWriteLikeUserOperation(contOp.getResult());
+      rewriter.setInsertionPoint(resultWriteOp);
+
+      Value indexOp_0 =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+      // Spill the accumulated tile into the temporary 16x16 buffer.
+      amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+                               ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+      auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+      auto one =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+      auto mBound =
+          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+      // For each of the 16 rows: load the row from the temporary buffer, add
+      // the matching row of the original accumulator, and store it back.
+      scf::ForOp::create(
+          rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+          [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
+            auto resultAcc = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), bBuffer,
+                ValueRange{iv, c0});
+
+            Operation *accReadOp =
+                traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+            Value srcBuffAcc;
+            SmallVector<Value> indicesAcc;
+
+            llvm::TypeSwitch<Operation *>(accReadOp)
+                .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+                  srcBuffAcc = readOp.getOperand(0);
+
+                  auto indices = readOp.getIndices();
+                  indicesAcc.reserve(indices.size());
+
+                  llvm::transform(
+                      indices, std::back_inserter(indicesAcc),
+                      [&](OpFoldResult ofr) {
+                        return mlir::getValueOrCreateConstantIndexOp(rewriter,
+                                                                     loc, ofr);
+                      });
+                });
+
+            // Offset the row index (second-to-last index) of the accumulator
+            // by the current row. The same index must be read and written;
+            // using index 0 as the base is only correct for rank-2 reads.
+            Value sum = arith::AddIOp::create(
+                builder, loc, iv, indicesAcc[indicesAcc.size() - 2]);
+            indicesAcc[indicesAcc.size() - 2] = sum;
+
+            auto acc = vector::LoadOp::create(rewriter, loc,
+                                              VectorType::get(16, opType),
+                                              srcBuffAcc, indicesAcc);
+            Value addition;
+            if (ipType.isBF16())
+              addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
+
+            if (ipType.isSignlessInteger(8))
+              addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
+
+            vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+                                    indicesAcc);
+
+            scf::YieldOp::create(builder, outerLoop.getLoc());
+          });
+
+      rewriter.eraseOp(resultWriteOp);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+// Adds the VectorContractToAMXDotProduct rewrite pattern to `patterns`.
+void x86::populateVectorContractToAMXDotProductPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
new file mode 100644
index 0000000000000..cde15b680a037
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -0,0 +1,1034 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_int8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x16x4xi8>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_int8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<1x16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @brgemm_bf16_loop(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32>  {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] :
+                memref<16x64x64x2xbf16> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] :
+                memref<16x64x128x2xbf16> to !memrefB
+          %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+
+          %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+          %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+          scf.yield %10, %12, %14, %15 : !vecC, !vecC, !vecC, !vecC
+        }
+        scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+      }
+
+      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+    }
+  }
+
+  return %arg2 : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @brgemm_bf16_loop
+// CHECK-2: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) { 
+// CHECK-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-4: x86.amx.tile_load
+// CHECK-4: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @matmul_int8_loop(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+        %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+                memref<64x128x4xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+        %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+        scf.yield %8, %9 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: @matmul_int8_loop
+// CHECK-2: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-3: x86.amx.tile_load
+// CHECK-2: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x4xi8>
+!vecC = vector<1x16x16xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @batch_matmul_int8_loop(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      scf.for %arg5 = %c0 to %c16 step %c1 {
+
+        %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+                memref<16x64x128xi32> to !memrefC
+        %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %4:2 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3) -> (!vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+                memref<16x64x64x4xi8> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+                memref<16x64x128x4xi8> to !memrefB
+          %5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %7 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg7 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg8 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          scf.yield %8, %9 : !vecC, !vecC
+        }
+
+        vector.transfer_write %4#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
+        vector.transfer_write %4#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @batch_matmul_int8_loop
+// CHECK-2: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-3: x86.amx.tile_load
+// CHECK-2: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x32x16x2xbf16>
+!vecB = vector<1x16x32x2xbf16>
+!vecC = vector<1x32x32xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_dimensions(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_dimensions
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store 
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_memref_source_LHS(
+  %arg0: !vecA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_memref_source_LHS
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_vnni_packed(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_vnni_packed
+// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @negative_VCs_LHS_src_differ(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>, %arg13: memref<64x64x4xi8>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+        %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_negative = memref.subview %arg13[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+                memref<64x128x4xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %odd_load = vector.transfer_read %subview_negative[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+        %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %odd_load, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+        scf.yield %8, %9 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: @negative_VCs_LHS_src_differ
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK-NOT: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (vector<16x16xi32>, vector<16x16xi32>) {
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x32x4xi8>
+!vecC = vector<1x16x32xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @negative_wrong_N_dim(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      scf.for %arg5 = %c0 to %c16 step %c1 {
+
+        %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+                memref<16x64x128xi32> to !memrefC
+        %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %3 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2) -> (!vecC) {
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+                memref<16x64x64x4xi8> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+                memref<16x64x128x4xi8> to !memrefB
+          %4 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecA
+          %5 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecB
+          %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %4, %5, %arg7 {unroll_shape = array<i64: 1, 4, 16, 32, 16>} : !vecA, !vecB into !vecC
+          scf.yield %6 : !vecC
+        }
+        vector.transfer_write %3, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @negative_wrong_N_dim
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<1x1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x1x16x32x4xi8, strided<[524288, 32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+
+func.func @negative_reduction_loop_depth_3(%arg0: memref<2x16x64x64x4xi8>, %arg1: memref<2x16x64x128x4xi8>, %arg2: memref<64x128xi32>) attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c2 = arith.constant 2 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c2 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %5:2 = scf.for %arg8 = %c0 to %c16 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+          %6:2 = scf.for %arg11 = %c0 to %c64 step %c16 iter_args(%arg12 = %arg9, %arg13 = %arg10) -> (!vecC, !vecC) {
+            %subview_0 = memref.subview %arg0[%arg5, %arg8, %arg3, %arg11, 0] [1, 1, 16, 16, 4] [1, 1, 1, 1, 1] :
+                memref<2x16x64x64x4xi8> to !memrefA
+            %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg11, %arg4, 0] [1, 1, 16, 32, 4] [1, 1, 1, 1, 1] :
+                memref<2x16x64x128x4xi8> to !memrefB
+            %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefA, !vecAB
+            %8 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefB, !vecAB
+            %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefB, !vecAB
+
+            %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %8, %arg12 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+            %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %9, %arg13 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+            scf.yield %10, %11 : !vecC, !vecC
+          }
+          scf.yield %6#0, %6#1 : !vecC, !vecC
+        }
+        scf.yield %5#0, %5#1 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+  return
+}
+
+
+// CHECK-LABEL: @negative_reduction_loop_depth_3
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xf16>
+!vecB = vector<1x16x16x2xf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xf16>
+!memrefB = memref<1x16x32x2xf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_wrong_type_f16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : f16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_type_f16
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xbf16>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_wrong_acc_type_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : bf16
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_acc_type_bf16
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+


        


More information about the Mlir-commits mailing list