[Mlir-commits] [mlir] [mlir][x86] Lower vector.contract to packed type tiled dot-product. (PR #182810)

Arun Thangamani llvmlistbot at llvm.org
Wed Mar 4 22:21:52 PST 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/182810

>From 4700e8d10e8a4dca1b466dffc4a0b1552d324967 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 22:47:13 -0800
Subject: [PATCH 1/7] initial commit for VC to tiled-dp lowering

---
 mlir/include/mlir/Dialect/AMX/CMakeLists.txt  |   2 +
 .../AMX/TransformOps/AMXTransformOps.h        |  31 +
 .../AMX/TransformOps/AMXTransformOps.td       |  32 +
 .../Dialect/AMX/TransformOps/CMakeLists.txt   |   4 +
 mlir/include/mlir/Dialect/AMX/Transforms.h    |   7 +
 mlir/lib/Dialect/AMX/CMakeLists.txt           |   1 +
 .../AMX/TransformOps/AMXTransformOps.cpp      |  57 ++
 .../Dialect/AMX/TransformOps/CMakeLists.txt   |  17 +
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |   1 +
 ...torContractToPackedTypeTiledDotProduct.cpp | 586 ++++++++++++++++++
 mlir/lib/RegisterAllExtensions.cpp            |   2 +
 .../AMX/vector-contract-to-tiled-dp.mlir      | 242 ++++++++
 12 files changed, 982 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
 create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
 create mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
 create mode 100644 mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
 create mode 100644 mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir

diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f875c78d240cc..b211cb3b53fd7 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
 add_mlir_interface(AMXInterfaces)
 add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
new file mode 100644
index 0000000000000..8806635df8eb5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
@@ -0,0 +1,31 @@
+//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// AMX Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace amx {
+/// Registers the AMX transform dialect extension, which provides the
+/// `transform.apply_patterns.amx.*` pattern-descriptor ops, in `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
new file mode 100644
index 0000000000000..74bcba3c37e57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
@@ -0,0 +1,32 @@
+//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMX_TRANSFORM_OPS
+#define AMX_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns that lower `vector.contract` operations with BF16 or
+    signless Int8 inputs to AMX tiled dot-products. BF16 inputs accumulate
+    into F32 tiles and Int8 inputs accumulate into I32 tiles.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
+#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..41255b936be71
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
+mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 7391ec2ff6b14..f95bf296c1555 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -28,6 +28,13 @@ void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 /// Register LLVM conversion interface for AMX dialect.
 void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
 
+namespace amx {
+
+/// Populates `patterns` with a rewrite that lowers BF16 / Int8
+/// `vector.contract` operations to AMX tile loads and tiled dot-products.
+void populateVectorContractToPackedTypeTiledDotProductPatterns(
+    RewritePatternSet &patterns);
+
+} // namespace amx
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
new file mode 100644
index 0000000000000..6ff573407b42b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
@@ -0,0 +1,57 @@
+//===- AMXTransformOps.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::amx;
+using namespace mlir::transform;
+
+/// Pattern-descriptor hook: forwards to the AMX population entry point so the
+/// op can be used inside a `transform.apply_patterns` region.
+void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  mlir::amx::populateVectorContractToPackedTypeTiledDotProductPatterns(
+      patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the AMX transform ops.
+class AMXTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          AMXTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
+
+  AMXTransformDialectExtension() {
+    // Dialects whose ops the exposed patterns may create must be declared so
+    // they are loaded before the patterns run.
+    declareGeneratedDialect<amx::AMXDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    // The vector.contract lowering emits vector.load / vector.store.
+    declareGeneratedDialect<vector::VectorDialect>();
+    // TODO: the lowering also creates arith, memref and scf ops; declare those
+    // dialects as well (requires the matching includes and link libraries).
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
+
+/// Adds the AMX transform dialect extension (see AMXTransformDialectExtension)
+/// to `registry`.
+void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<AMXTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..30b4304586ab7
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRAMXTransformOps
+  AMXTransformOps.cpp
+
+  DEPENDS
+  MLIRAMXTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRAMXDialect
+  MLIRAMXTransforms
+  )
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index e827bc475e930..e422f275c7742 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRAMXTransforms
   LegalizeForLLVMExport.cpp
+  VectorContractToPackedTypeTiledDotProduct.cpp
 
   LINK_LIBS PUBLIC
   MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
new file mode 100644
index 0000000000000..37fd3bb33848b
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -0,0 +1,586 @@
+//===- VectorContractToPackedTypeTiledDotProduct.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::amx;
+
+namespace {
+
+/// Walks backwards from `v` to the vector read that produces it.
+///
+/// Returns the defining `vector.transfer_read` / `vector.load`. `scf.for`
+/// iter_args are followed to the loop's init value, as many levels as needed.
+/// Returns nullptr for any other producer, for the loop induction variable,
+/// for block arguments of non-`scf.for` ops, and for a null `v`.
+static Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (v) {
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+      // Any other producer (shape_cast, shuffle, arithmetic, ...) is not
+      // supported.
+      return nullptr;
+    }
+
+    // `v` is a block argument; only scf.for iter_args are looked through.
+    auto barg = cast<BlockArgument>(v);
+    // The owning block may be detached (no parent op), so the cast must
+    // tolerate null.
+    auto forOp = llvm::dyn_cast_if_present<scf::ForOp>(
+        barg.getOwner()->getParentOp());
+    // Argument 0 of the scf.for body is the induction variable.
+    if (!forOp || barg.getArgNumber() == 0)
+      return nullptr;
+    v = forOp.getInitArgs()[barg.getArgNumber() - 1];
+  }
+  return nullptr;
+}
+
+/// Walks forwards from `v` to a vector write that (transitively) consumes it.
+///
+/// Uses are explored depth-first. The search passes through:
+///  * generic users, via their results;
+///  * `scf.yield`, via the matching result of the parent op;
+///  * `scf.for` init operands, via the tied loop result.
+/// Returns the first `vector.transfer_write` / `vector.store` reached, or
+/// nullptr when none is found. A `vector.shape_cast` / `vector.shuffle` user
+/// aborts the search among the remaining uses of the current value.
+static Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    // --- TERMINAL OPS ---
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    // Layout-changing users are not supported.
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    // --- SCF YIELD ---
+    // Yield operand #i corresponds to result #i of the parent for scf.for,
+    // scf.if, etc. The yield in scf.while's "after" region instead feeds the
+    // "before" region's arguments, so it is not followed.
+    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+      Operation *parent = yield->getParentOp();
+      unsigned idx = use.getOperandNumber();
+      if (isa<scf::WhileOp>(parent) || idx >= parent->getNumResults())
+        continue;
+      if (Operation *res =
+              traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+        return res;
+      continue;
+    }
+
+    // --- SCF FOR ---
+    // The operand number counts the lower bound, upper bound and step first,
+    // so it must be rebased before it can index the loop results.
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      unsigned numControl = forOp.getNumControlOperands();
+      unsigned idx = use.getOperandNumber();
+      if (idx < numControl)
+        continue;
+      if (Operation *res = traceToVectorWriteLikeUserOperation(
+              forOp.getResult(idx - numControl)))
+        return res;
+      continue;
+    }
+
+    // --- GENERIC CASE ---
+    for (Value res : user->getResults()) {
+      if (Operation *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+    }
+  }
+
+  return nullptr;
+}
+
+/// Collapses dimensions [firstDimToCollapse, rank) of the memref `input` into
+/// one trailing dimension with `memref.collapse_shape`. The leading
+/// dimensions keep their own reassociation groups.
+///
+/// Returns `input` unchanged when there is nothing to collapse: rank <= 1, or
+/// at most one trailing dimension selected. A negative `firstDimToCollapse`
+/// is treated as 0 (collapse every dimension).
+static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  ShapedType inputType = cast<ShapedType>(input.getType());
+  int64_t rank = inputType.getRank();
+  if (firstDimToCollapse < 0)
+    firstDimToCollapse = 0;
+  if (rank <= 1 || firstDimToCollapse >= rank - 1)
+    return input;
+
+  SmallVector<ReassociationIndices> reassociation;
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+  ReassociationIndices collapsedIndices;
+  for (int64_t i = firstDimToCollapse; i < rank; ++i)
+    collapsedIndices.push_back(i);
+  reassociation.push_back(collapsedIndices);
+  return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
+}
+
+/// Returns the base memref and indices of the vector read that defines
+/// `operand` (a `vector.transfer_read` or `vector.load`). Constant indices
+/// are materialized with getValueOrCreateConstantIndexOp.
+///
+/// When `isNotAcc` is true (LHS/RHS operands), the innermost index is dropped
+/// and the base memref's trailing two dimensions are collapsed. The buffer
+/// rank then matches the shortened index list.
+/// Returns an empty pair if `operand` is not produced by a supported read;
+/// callers must check the returned buffer before using it.
+static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
+                                                            Location loc,
+                                                            Value operand,
+                                                            bool isNotAcc) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return {};
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+        srcBuff = readOp.getBase();
+      });
+  if (!srcBuff)
+    return {};
+
+  // The innermost dimension is folded into the collapsed buffer below.
+  if (isNotAcc && !indexVals.empty())
+    indexVals.pop_back();
+
+  SmallVector<Value> indices;
+  indices.reserve(indexVals.size());
+  for (OpFoldResult ofr : indexVals)
+    indices.push_back(
+        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+
+  if (isNotAcc) {
+    int64_t rank = cast<ShapedType>(srcBuff.getType()).getRank();
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff, rank - 2);
+  }
+  return {srcBuff, indices};
+}
+
+/// Looks at the `memref.subview` that feeds the vector read defining
+/// `operand`. Returns the position, within that subview's offset operands,
+/// of the offset equal to `loop`'s induction variable.
+///
+/// Returns 0 in three cases: `operand` is not produced by a
+/// `vector.transfer_read` / `vector.load`, the read's base is not a
+/// `memref.subview`, or no offset is the induction variable.
+/// NOTE: 0 is also a valid position, so callers that need to tell these
+/// cases apart must check `subview.getOffsets()` themselves.
+static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
+  Value srcBuff;
+  if (auto readOp = operand.getDefiningOp<TransferReadOp>())
+    srcBuff = readOp.getBase();
+  else if (auto loadOp = operand.getDefiningOp<LoadOp>())
+    srcBuff = loadOp.getBase();
+  if (!srcBuff)
+    return 0;
+
+  auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return 0;
+
+  Value iv = loop.getInductionVar();
+  for (auto it : llvm::enumerate(subview.getOffsets())) {
+    if (it.value() == iv)
+      return it.index();
+  }
+  return 0;
+}
+
+/// Creates an `amx.tile_load` from the buffer `mat` for the vector read that
+/// defines `operand`.
+///
+/// The read's innermost index is dropped. For the RHS (`rhs` == true), the
+/// last remaining index is also multiplied by the packing factor: 2 for BF16,
+/// 4 for any other type. The tile is 16x32 for BF16 and 16x64 otherwise.
+static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
+                                       Value operand, Value mat, Type ipType,
+                                       bool rhs) {
+  const bool isBF16 = ipType.isBF16();
+
+  SmallVector<OpFoldResult> readIndices;
+  llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
+      .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+        readIndices = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                                readOp.getIndices().end());
+      });
+  // The innermost dimension is addressed by the tile's columns.
+  readIndices.pop_back();
+
+  SmallVector<Value> tileIndices;
+  tileIndices.reserve(readIndices.size());
+  for (OpFoldResult ofr : readIndices)
+    tileIndices.push_back(
+        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+
+  if (rhs) {
+    // Packed RHS: rescale the last index by the packing factor.
+    Value packFactor =
+        arith::ConstantIndexOp::create(rewriter, loc, isBF16 ? 2 : 4);
+    Value &lastIdx = tileIndices.back();
+    lastIdx = arith::MulIOp::create(rewriter, loc, lastIdx, packFactor);
+  }
+
+  amx::TileType tileType =
+      amx::TileType::get({16, isBF16 ? 32 : 64}, ipType);
+  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, tileIndices);
+}
+
+/// Emits one AMX tiled dot-product per contraction in `ops`.
+///
+/// The innermost two dimensions of the LHS / RHS buffers `matA` / `matB` are
+/// collapsed here. One tile load is created per distinct LHS read and per
+/// distinct RHS read, and reused by every contraction that consumes it.
+/// `accIterArgs[i]` is the accumulator tile for `ops[i]`. Returns the
+/// updated accumulator tiles in the order of `ops`. `ipType` must be BF16
+/// (uses TileMulFOp) or signless i8 (uses TileMulIOp).
+static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
+                                        ArrayRef<vector::ContractionOp> ops,
+                                        Value matA, Value matB, Type ipType,
+                                        Type opType, ValueRange accIterArgs) {
+  assert(accIterArgs.size() >= ops.size() &&
+         "expected one accumulator per contraction");
+  Value lhsBuffer = collapseInnerDims(
+      rewriter, loc, matA, cast<ShapedType>(matA.getType()).getRank() - 2);
+  Value rhsBuffer = collapseInnerDims(
+      rewriter, loc, matB, cast<ShapedType>(matB.getType()).getRank() - 2);
+
+  // LHS and RHS tiles are cached separately. A read used as RHS needs a
+  // different buffer and packed indexing, so it must never reuse an LHS tile.
+  using TileCache = llvm::DenseMap<Operation *, amx::TileLoadOp>;
+  TileCache lhsTiles;
+  TileCache rhsTiles;
+  auto getOrCreateTile = [&](TileCache &cache, Value operand, Value buffer,
+                             bool isRhs) -> amx::TileLoadOp {
+    Operation *readOp = operand.getDefiningOp();
+    auto it = cache.find(readOp);
+    if (it != cache.end())
+      return it->second;
+    amx::TileLoadOp tile =
+        createTileLoads(rewriter, loc, operand, buffer, ipType, isRhs);
+    cache.try_emplace(readOp, tile);
+    return tile;
+  };
+
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+  SmallVector<Value> accumulators;
+  accumulators.reserve(ops.size());
+  for (size_t i = 0, e = ops.size(); i < e; ++i) {
+    vector::ContractionOp contract = ops[i];
+    amx::TileLoadOp lhsTile =
+        getOrCreateTile(lhsTiles, contract.getLhs(), lhsBuffer, false);
+    amx::TileLoadOp rhsTile =
+        getOrCreateTile(rhsTiles, contract.getRhs(), rhsBuffer, true);
+
+    Value dp;
+    if (ipType.isBF16())
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, lhsTile,
+                                   rhsTile, accIterArgs[i]);
+    else if (ipType.isSignlessInteger(8))
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, lhsTile,
+                                   rhsTile, accIterArgs[i]);
+    assert(dp && "expected a BF16 or signless i8 input element type");
+    accumulators.push_back(dp);
+  }
+  return accumulators;
+}
+
+/// Creates `size` zero-filled 16x16 accumulator tiles of `opType` right
+/// before `outerLoop`; the caller uses them as initial loop iter_args.
+/// Leaves `rewriter`'s insertion point before `outerLoop`.
+static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
+                                          Type opType, scf::ForOp outerLoop,
+                                          int64_t size) {
+  auto accTileType = amx::TileType::get({16, 16}, opType);
+  rewriter.setInsertionPoint(outerLoop);
+
+  SmallVector<Value> zeroTiles;
+  while (static_cast<int64_t>(zeroTiles.size()) < size)
+    zeroTiles.push_back(amx::TileZeroOp::create(rewriter, loc, accTileType));
+  return zeroTiles;
+}
+
+/// Lowers `vector.contract` ops with BF16 or signless i8 inputs and an `add`
+/// combining kind to AMX tile operations.
+///
+/// Two forms are handled:
+///  * Single contraction: the accumulator read and the result write are in
+///    the same block as the contraction. LHS, RHS and accumulator are loaded
+///    as tiles, combined with one tiled dot-product and stored back to the
+///    accumulator buffer. Only BF16 inputs are supported in this form.
+///  * Reduction loop nest: the contraction is nested in one or two `scf.for`
+///    loops, the outermost of which is in the accumulator read's block. The
+///    nest is rebuilt with iter_args holding one zero-initialized accumulator
+///    tile per `vector.contract` of the innermost body. After the loop, each
+///    tile is stored to a 16x16 `memref.alloca` buffer and added row by row
+///    into the accumulator buffer, replacing the original result write.
+///
+/// NOTE(review): all contractions of the innermost body are assumed to read
+/// their LHS/RHS through the same subviews as the matched contraction. The
+/// original loop nest is left in place; only its result writes are erased.
+struct VectorContractToPackedTypeTiledDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // AMX tiled dot-products exist only for BF16 (-> F32) and signless i8
+    // (-> i32) inputs; other element types would get the wrong tile types.
+    Type lhsElemType = contractOp.getLhsType().getElementType();
+    const bool isBF16 = lhsElemType.isBF16();
+    if (!isBF16 && !lhsElemType.isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects BF16 or signless i8 input element type.");
+
+    Operation *accReadOp =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+    Operation *resultWriteOp =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+    if (!accReadOp || !resultWriteOp)
+      return failure();
+
+    if (!isProducedByVectorRead(contractOp.getLhs()) ||
+        !isProducedByVectorRead(contractOp.getRhs()))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects LHS and RHS produced by vector reads.");
+
+    // Case for just one vector.contract rewrite.
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock()) {
+      // This form hard-codes BF16 input tiles and F32 accumulators.
+      if (!isBF16)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Single-contraction lowering supports BF16 only.");
+      rewriteSingleContraction(contractOp, resultWriteOp, rewriter);
+      return success();
+    }
+
+    return rewriteLoopNest(contractOp, accReadOp, isBF16, rewriter);
+  }
+
+private:
+  /// Returns true if `v` is the result of a vector.transfer_read/vector.load.
+  static bool isProducedByVectorRead(Value v) {
+    return llvm::isa_and_present<vector::TransferReadOp, vector::LoadOp>(
+        v.getDefiningOp());
+  }
+
+  /// Returns the memref read by the vector.transfer_read / vector.load that
+  /// produces `v`, or a null value for any other producer.
+  static Value getReadBase(Value v) {
+    if (auto readOp = v.getDefiningOp<vector::TransferReadOp>())
+      return readOp.getBase();
+    if (auto loadOp = v.getDefiningOp<vector::LoadOp>())
+      return loadOp.getBase();
+    return Value();
+  }
+
+  /// Lowers one BF16 contraction whose accumulator read and result write are
+  /// in its own block, then erases `resultWriteOp`.
+  static void rewriteSingleContraction(vector::ContractionOp contractOp,
+                                       Operation *resultWriteOp,
+                                       PatternRewriter &rewriter) {
+    Location loc = contractOp.getLoc();
+    auto inputTileType = amx::TileType::get({16, 32}, rewriter.getBF16Type());
+
+    auto [lhsBuff, lhsIndices] =
+        getSrcIndxValue(rewriter, loc, contractOp.getLhs(), true);
+    auto lhsTile = amx::TileLoadOp::create(rewriter, loc, inputTileType,
+                                           lhsBuff, lhsIndices);
+
+    auto [rhsBuff, rhsIndices] =
+        getSrcIndxValue(rewriter, loc, contractOp.getRhs(), true);
+    auto rhsTile = amx::TileLoadOp::create(rewriter, loc, inputTileType,
+                                           rhsBuff, rhsIndices);
+
+    auto [accBuff, accIndices] =
+        getSrcIndxValue(rewriter, loc, contractOp.getAcc(), false);
+    auto accTileType = amx::TileType::get({16, 16}, rewriter.getF32Type());
+    auto accTile = amx::TileLoadOp::create(rewriter, loc, accTileType,
+                                           accBuff, accIndices);
+
+    auto dp = amx::TileMulFOp::create(rewriter, loc, accTileType, lhsTile,
+                                      rhsTile, accTile);
+    amx::TileStoreOp::create(rewriter, loc, accBuff, accIndices, dp);
+
+    rewriter.eraseOp(resultWriteOp);
+  }
+
+  /// Rebuilds the one- or two-deep `scf.for` nest around `contractOp` with
+  /// AMX accumulator tiles as iter_args, then folds the tiles back into the
+  /// accumulator buffers. Every precondition is checked before any IR is
+  /// created, so a match failure leaves the IR untouched.
+  static LogicalResult rewriteLoopNest(vector::ContractionOp contractOp,
+                                       Operation *accReadOp, bool isBF16,
+                                       PatternRewriter &rewriter) {
+    // Collect the enclosing loops, innermost first, up to the loop that sits
+    // in the accumulator read's block.
+    SmallVector<scf::ForOp> loops;
+    Operation *current = contractOp;
+    while (true) {
+      auto parent = current->getParentOfType<scf::ForOp>();
+      if (!parent)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Expects an enclosing scf.for in the accumulator "
+                        "read's block.");
+      loops.push_back(parent);
+      if (parent->getBlock() == accReadOp->getBlock())
+        break;
+      if (loops.size() == 2)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Expects at most two enclosing scf.for loops.");
+      current = parent;
+    }
+    scf::ForOp innerLoop = loops.front();
+    scf::ForOp outerLoop = loops.back();
+
+    // Every contraction directly in the innermost body gets one accumulator.
+    SmallVector<vector::ContractionOp> ops;
+    for (auto contract : innerLoop.getBody()->getOps<vector::ContractionOp>())
+      ops.push_back(contract);
+    // Otherwise the matched op's write would never be erased and the pattern
+    // would apply forever.
+    if (!llvm::is_contained(ops, contractOp))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects the contraction directly in the loop body.");
+
+    SmallVector<Operation *> accReads;
+    SmallVector<Operation *> resultWrites;
+    for (vector::ContractionOp op : ops) {
+      Operation *readOp = traceToVectorReadLikeParentOperation(op.getAcc());
+      Operation *writeOp = traceToVectorWriteLikeUserOperation(op.getResult());
+      if (!readOp || !writeOp || llvm::is_contained(resultWrites, writeOp) ||
+          !isProducedByVectorRead(op.getLhs()) ||
+          !isProducedByVectorRead(op.getRhs()))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Expects every contraction in the loop body to read "
+                        "its operands from memory and write its result to a "
+                        "distinct location.");
+      accReads.push_back(readOp);
+      resultWrites.push_back(writeOp);
+    }
+
+    // LHS/RHS are re-read inside the new loops through clones of their
+    // subviews, so each subview must take every loop's IV as an offset.
+    auto lhsSubView = getReadBase(contractOp.getLhs())
+                          .getDefiningOp<memref::SubViewOp>();
+    auto rhsSubView = getReadBase(contractOp.getRhs())
+                          .getDefiningOp<memref::SubViewOp>();
+    if (!lhsSubView || !rhsSubView)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects LHS and RHS read from memref.subview ops.");
+    for (scf::ForOp loop : loops) {
+      Value iv = loop.getInductionVar();
+      if (!llvm::is_contained(lhsSubView.getOffsets(), iv) ||
+          !llvm::is_contained(rhsSubView.getOffsets(), iv))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Expects LHS/RHS subview offsets to use every loop "
+                        "induction variable.");
+    }
+
+    Type ipType = rewriter.getIntegerType(8);
+    Type opType = rewriter.getIntegerType(32);
+    if (isBF16) {
+      ipType = rewriter.getBF16Type();
+      opType = rewriter.getF32Type();
+    }
+
+    // --- IR mutation starts here; no failure may be returned below. ---
+    Location loc = outerLoop.getLoc();
+    SmallVector<Value> zeroTiles =
+        createTileZeros(rewriter, loc, opType, outerLoop, ops.size());
+
+    // Clones `subview` at `b`'s insertion point, replacing each original
+    // loop's induction variable with the corresponding new one.
+    auto cloneSubView = [&](OpBuilder &b, memref::SubViewOp subview,
+                            Value operand, ArrayRef<scf::ForOp> oldLoops,
+                            ArrayRef<Value> newIvs) -> Value {
+      IRMapping mapping;
+      for (auto [oldLoop, newIv] : llvm::zip_equal(oldLoops, newIvs))
+        mapping.map(
+            subview->getOperand(getIndexPosition(operand, oldLoop) + 1),
+            newIv);
+      return b.clone(*subview.getOperation(), mapping)->getResult(0);
+    };
+
+    scf::ForOp newLoop;
+    if (loops.size() == 2) {
+      newLoop = scf::ForOp::create(
+          rewriter, loc, outerLoop.getLowerBound(), outerLoop.getUpperBound(),
+          outerLoop.getStep(), zeroTiles,
+          [&](OpBuilder &outerBuilder, Location outerLoc, Value outerIv,
+              ValueRange outerIterArgs) {
+            auto newInnerLoop = scf::ForOp::create(
+                rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                innerLoop.getUpperBound(), innerLoop.getStep(), outerIterArgs,
+                [&](OpBuilder &innerBuilder, Location innerLoc, Value innerIv,
+                    ValueRange innerIterArgs) {
+                  Value lhsView = cloneSubView(
+                      innerBuilder, lhsSubView, contractOp.getLhs(),
+                      {outerLoop, innerLoop}, {outerIv, innerIv});
+                  Value rhsView = cloneSubView(
+                      innerBuilder, rhsSubView, contractOp.getRhs(),
+                      {outerLoop, innerLoop}, {outerIv, innerIv});
+                  SmallVector<Value> accumulators =
+                      createTiledDp(rewriter, innerLoc, ops, lhsView, rhsView,
+                                    ipType, opType, innerIterArgs);
+                  scf::YieldOp::create(innerBuilder, innerLoc, accumulators);
+                });
+            scf::YieldOp::create(outerBuilder, outerLoc,
+                                 newInnerLoop.getResults());
+          });
+    } else {
+      newLoop = scf::ForOp::create(
+          rewriter, loc, outerLoop.getLowerBound(), outerLoop.getUpperBound(),
+          outerLoop.getStep(), zeroTiles,
+          [&](OpBuilder &loopBuilder, Location loopLoc, Value iv,
+              ValueRange iterArgs) {
+            Value lhsView = cloneSubView(loopBuilder, lhsSubView,
+                                         contractOp.getLhs(), {outerLoop}, {iv});
+            Value rhsView = cloneSubView(loopBuilder, rhsSubView,
+                                         contractOp.getRhs(), {outerLoop}, {iv});
+            SmallVector<Value> accumulators =
+                createTiledDp(rewriter, loopLoc, ops, lhsView, rhsView, ipType,
+                              opType, iterArgs);
+            scf::YieldOp::create(loopBuilder, loopLoc, accumulators);
+          });
+    }
+
+    // Post-processing: add each accumulator tile into its original
+    // accumulator buffer, row by row, in place of the original result write.
+    auto bufferType = MemRefType::get({16, 16}, opType);
+    auto tileBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+    ValueRange dps = newLoop.getResults();
+    auto rowVecType = VectorType::get(16, opType);
+
+    for (size_t i = 0, e = ops.size(); i < e; ++i) {
+      Operation *writeOp = resultWrites[i];
+      Operation *readOp = accReads[i];
+      rewriter.setInsertionPoint(writeOp);
+
+      Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+      Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+      amx::TileStoreOp::create(rewriter, loc, tileBuffer, ValueRange{c0, c0},
+                               dps[i]);
+
+      scf::ForOp::create(
+          rewriter, loc, c0, c16, c1, ValueRange{},
+          [&](OpBuilder &b, Location rowLoc, Value row, ValueRange) {
+            Value tileRow = vector::LoadOp::create(
+                rewriter, rowLoc, rowVecType, tileBuffer, ValueRange{row, c0});
+
+            auto [accBuff, accIndices] =
+                getSrcIndxValue(b, rowLoc, readOp->getResult(0), false);
+            // Offset the accumulator's row index by the current tile row.
+            accIndices[0] =
+                arith::AddIOp::create(b, rowLoc, row, accIndices[0]);
+            Value accRow = vector::LoadOp::create(rewriter, rowLoc, rowVecType,
+                                                  accBuff, accIndices);
+
+            Value sum;
+            if (isBF16)
+              sum = arith::AddFOp::create(rewriter, rowLoc, tileRow, accRow);
+            else
+              sum = arith::AddIOp::create(rewriter, rowLoc, tileRow, accRow);
+
+            vector::StoreOp::create(b, rowLoc, sum, accBuff, accIndices);
+            scf::YieldOp::create(b, loc);
+          });
+
+      rewriter.eraseOp(writeOp);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the VectorContractToPackedTypeTiledDotProduct rewrite to `patterns`.
+void amx::populateVectorContractToPackedTypeTiledDotProductPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<VectorContractToPackedTypeTiledDotProduct>(ctx);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 14cfb7b5ac352..dbd482ba85ee1 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -97,6 +98,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
+  amx::registerTransformDialectExtension(registry);
   bufferization::registerTransformDialectExtension(registry);
   dlti::registerTransformDialectExtension(registry);
   func::registerTransformDialectExtension(registry);
diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
new file mode 100644
index 0000000000000..081eb56a029f6
--- /dev/null
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -0,0 +1,242 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+// Batch-reduce GEMM: reduces over d0 (batch), d1 (VNNI pair) and d4 (K).
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+  module {
+    func.func @brgemm_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+      %0 = ub.poison : f32
+      %1 = ub.poison : bf16
+      %c0 = arith.constant 0 : index
+      %c64 = arith.constant 64 : index
+      %c128 = arith.constant 128 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c1 = arith.constant 1 : index
+      scf.for %arg3 = %c0 to %c64 step %c32 {
+        scf.for %arg4 = %c0 to %c128 step %c32 {
+          %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
+          %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+            %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+              %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+              %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+              %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+              %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+              %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+              %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
+              scf.yield %10, %12, %14, %15 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+            }
+            scf.yield %7#0, %7#1, %7#2, %7#3 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+          }
+          vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+        }
+      }
+      %alloc = memref.alloc() : memref<64x128xf32>
+      memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+      return %alloc : memref<64x128xf32>
+    }
+  }
+
+// CHECK-LABEL: @brgemm_amx
+// CHECK-NOT: vector.contract
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Non-batched GEMM: no batch dim; reduces over d0 (VNNI pair) and d3 (K).
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+  module {
+    func.func @gemm_amx(%arg0: memref<64x64x2xbf16>, %arg1: memref<64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+      %0 = ub.poison : f32
+      %1 = ub.poison : bf16
+      %c0 = arith.constant 0 : index
+      %c64 = arith.constant 64 : index
+      %c128 = arith.constant 128 : index
+      %c32 = arith.constant 32 : index
+      %c16 = arith.constant 16 : index
+      scf.for %arg3 = %c0 to %c64 step %c32 {
+        scf.for %arg4 = %c0 to %c128 step %c32 {
+          %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
+          %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
+          %6:4 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
+            %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [32, 16, 2] [1, 1, 1] : memref<64x64x2xbf16> to memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>
+            %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 2] [1, 1, 1] : memref<64x128x2xbf16> to memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>
+            %7 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+            %8 = vector.transfer_read %subview_0[%c16, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+            %9 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+            %10 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
+            %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg6 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+            %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg7 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+            %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg8 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+            %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg9 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
+            scf.yield %11, %12, %13, %14 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+          }
+          vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+          vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+        }
+      }
+      %alloc = memref.alloc() : memref<64x128xf32>
+      memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+      return %alloc : memref<64x128xf32>
+    }
+  }
+
+// CHECK-LABEL: @gemm_amx
+// CHECK-NOT: vector.contract
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Batched GEMM: d0 is a parallel batch dim; reduces over d1 (VNNI) and d4 (K).
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+  module {
+    func.func @batch_matmul_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<16x64x128xf32>) -> memref<16x64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+      %0 = ub.poison : f32
+      %1 = ub.poison : bf16
+      %c0 = arith.constant 0 : index
+      %c64 = arith.constant 64 : index
+      %c128 = arith.constant 128 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c1 = arith.constant 1 : index
+      scf.for %arg3 = %c0 to %c64 step %c32 {
+        scf.for %arg4 = %c0 to %c128 step %c32 {
+          scf.for %arg5 = %c0 to %c16 step %c1 {
+            %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 32, 32] [1, 1, 1] : memref<16x64x128xf32> to memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+            %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+            %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+            %4 = vector.transfer_read %subview[%c0, %c16, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+            %5 = vector.transfer_read %subview[%c0, %c16, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
+            %6:4 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3, %arg9 = %4, %arg10 = %5) -> (vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>) {
+              %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+              %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+              %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %8 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %10 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
+              %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+              %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg8 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+              %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg9 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+              %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg10 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
+              scf.yield %11, %12, %13, %14 : vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>
+            }
+            vector.transfer_write %6#3, %subview[%c0, %c16, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+            vector.transfer_write %6#2, %subview[%c0, %c16, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+            vector.transfer_write %6#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+            vector.transfer_write %6#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
+          }
+        }
+      }
+      %alloc = memref.alloc() : memref<16x64x128xf32>
+      memref.copy %arg2, %alloc : memref<16x64x128xf32> to memref<16x64x128xf32>
+      return %alloc : memref<16x64x128xf32>
+    }
+  }
+
+// CHECK-LABEL: @batch_matmul_amx
+// CHECK-NOT: vector.contract
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Loop-free batch-reduce GEMM: reduces over d0 (batch), d4 (VNNI) and d3 (K).
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @amx(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @amx
+// CHECK-NOT: vector.contract
+// CHECK: amx.tile_mulf
+// CHECK-NOT: vector.contract
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+

>From 43b55930d0edf7e032ecce2738d2e610fc1c3122 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 22:55:44 -0800
Subject: [PATCH 2/7] fix code formatting errors

---
 mlir/lib/RegisterAllExtensions.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index dbd482ba85ee1..6f01088a0ac3b 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -32,8 +32,8 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"

>From 4097bfb53ce3ab19376a99c7e875753d53001811 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 22 Feb 2026 23:27:54 -0800
Subject: [PATCH 3/7] replacing common fun with x86vector utils API

---
 ...torContractToPackedTypeTiledDotProduct.cpp | 96 ++-----------------
 1 file changed, 6 insertions(+), 90 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index 37fd3bb33848b..faac5530373e9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
 
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
@@ -26,85 +27,6 @@ using namespace mlir::amx;
 
 namespace {
 
-static Operation *traceToVectorReadLikeParentOperation(Value v) {
-  while (true) {
-    // Case 1: Value defined by an operation
-    if (Operation *defOp = v.getDefiningOp()) {
-      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp)) {
-        return defOp;
-      }
-
-      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp)) {
-        return nullptr;
-      }
-
-      return nullptr;
-    }
-
-    // Case 2: BlockArgument (scf.for iter_arg)
-    if (auto barg = dyn_cast<BlockArgument>(v)) {
-      auto *parentOp = barg.getOwner()->getParentOp();
-
-      if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
-        unsigned argNum = barg.getArgNumber();
-
-        // arg0 = induction variable (not an iter_arg)
-        if (argNum == 0)
-          return nullptr;
-
-        unsigned iterIdx = argNum - 1;
-        v = forOp.getInitArgs()[iterIdx];
-        continue;
-      }
-
-      return nullptr;
-    }
-
-    return nullptr;
-  }
-}
-
-static Operation *traceToVectorWriteLikeUserOperation(Value v) {
-  for (OpOperand &use : v.getUses()) {
-    Operation *user = use.getOwner();
-
-    // --- TERMINAL OPS ---
-    if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user)) {
-      return user;
-    }
-
-    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user)) {
-      return nullptr;
-    }
-
-    // --- SCF YIELD ---
-    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
-      Operation *parent = yield->getParentOp();
-      unsigned idx = use.getOperandNumber();
-      if (auto *res =
-              traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
-        return res;
-      continue;
-    }
-
-    // --- SCF FOR ---
-    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-      unsigned idx = use.getOperandNumber();
-      if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
-        return res;
-      continue;
-    }
-
-    // --- GENERIC CASE ---
-    for (Value res : user->getResults()) {
-      if (auto *found = traceToVectorWriteLikeUserOperation(res))
-        return found;
-    }
-  }
-
-  return nullptr;
-}
-
 static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
                                Value input, int64_t firstDimToCollapse) {
   ShapedType inputType = cast<ShapedType>(input.getType());
@@ -297,13 +219,6 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
-//   vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
-// ```
-// to
-// ```
-//   vector.broadcast %lhs to <16xf32>
-//   vector.fma vector<16xf32>
-// ```
 struct VectorContractToPackedTypeTiledDotProduct
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -316,10 +231,10 @@ struct VectorContractToPackedTypeTiledDotProduct
                                          "Expects add combining kind.");
 
     Operation *accReadOp =
-        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+        x86vector::traceToVectorReadLikeParentOperation(contractOp.getAcc());
 
     Operation *resultWriteOp =
-        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+        x86vector::traceToVectorWriteLikeUserOperation(contractOp.getResult());
 
     if (!accReadOp || !resultWriteOp)
       return failure();
@@ -510,7 +425,7 @@ struct VectorContractToPackedTypeTiledDotProduct
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
       Operation *resultWriteOp =
-          traceToVectorWriteLikeUserOperation(contOp.getResult());
+          x86vector::traceToVectorWriteLikeUserOperation(contOp.getResult());
       rewriter.setInsertionPoint(resultWriteOp);
 
       Value indexOp_0 =
@@ -533,7 +448,8 @@ struct VectorContractToPackedTypeTiledDotProduct
                                                bBuffer, ValueRange{iv, c0});
 
             Operation *readOp1 =
-                traceToVectorReadLikeParentOperation(ops[i].getAcc());
+                x86vector::traceToVectorReadLikeParentOperation(
+                    ops[i].getAcc());
 
             Value srcBuff;
             SmallVector<OpFoldResult> indexVals;

>From 4f74ede48afb561423fe79204b35d6415539c97e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 2 Mar 2026 23:45:37 -0800
Subject: [PATCH 4/7] clean-up + more test-cases

---
 mlir/include/mlir/Dialect/AMX/Transforms.h    |   7 +-
 ...torContractToPackedTypeTiledDotProduct.cpp | 515 +++++++----
 .../AMX/vector-contract-to-tiled-dp.mlir      | 867 +++++++++++++++---
 3 files changed, 1089 insertions(+), 300 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index f95bf296c1555..50eab54ac58ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -30,10 +30,13 @@ void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
 
 namespace amx {
 
+// Populates `patterns` with rewrites that lower vector.contract ops on 32-bit
+// packed operands (e.g., BF16 or Int8 elements packed into 32 bits) to AMX
+// packed-type tiled dot-product operations, which ultimately map to the
+// relevant x86 LLVM intrinsics.
 void populateVectorContractToPackedTypeTiledDotProductPatterns(
     RewritePatternSet &patterns);
-
-}
+} // namespace amx
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index faac5530373e9..af4f3ef1c934a 100644
--- a/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -27,25 +27,37 @@ using namespace mlir::amx;
 
 namespace {
 
+// Collapses the two innermost dims of MemRef `input` (the VNNI dim and the
+// dim before it) into one so amx.tile_load can load the packed element type.
 static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
-                               Value input, int64_t firstDimToCollapse) {
+                               Value input) {
   ShapedType inputType = cast<ShapedType>(input.getType());
+  int64_t firstDimToCollapse = inputType.getRank() - 2;
+  // A rank-1 MemRef has no inner pair of dims to collapse; return it as is.
   if (inputType.getRank() == 1)
     return input;
+  // Every dim before the trailing pair keeps its own reassociation group.
   SmallVector<ReassociationIndices> reassociation;
   for (int64_t i = 0; i < firstDimToCollapse; ++i)
     reassociation.push_back(ReassociationIndices{i});
+  // The trailing two dims are merged into a single reassociation group.
   ReassociationIndices collapsedIndices;
   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
     collapsedIndices.push_back(i);
+
   reassociation.push_back(collapsedIndices);
   return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
 }
 
-static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
-                                                            Location loc,
-                                                            Value operand,
-                                                            bool isNotAcc) {
+// Returns the source MemRef and offset indices of the read op that defines
+// the given vector.contract operand, or failure if there is none.
+static FailureOr<std::pair<Value, SmallVector<Value>>>
+getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
+                bool isNotAcc) {
+  Operation *defOp = operand.getDefiningOp();
+  if (!defOp)
+    return failure();
+
   Value srcBuff;
   SmallVector<OpFoldResult> indexVals;
   llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
@@ -55,6 +67,9 @@ static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
         srcBuff = readOp.getOperand(0);
       });
 
+  if (!srcBuff)
+    return failure();
+
   if (isNotAcc) {
     indexVals.pop_back();
   }
@@ -68,24 +83,95 @@ static std::pair<Value, SmallVector<Value>> getSrcIndxValue(OpBuilder &rewriter,
   }
 
   if (isNotAcc) {
-    auto subviewType = cast<ShapedType>(srcBuff.getType());
-    auto subviewRank = subviewType.getRank();
-    srcBuff = collapseInnerDims(rewriter, loc, srcBuff, subviewRank - 2);
+    srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
+  }
+
+  return std::make_pair(srcBuff, indices);
+}
+
+// Checks that `contractOp` can be lowered to an AMX tiled dot-product:
+//  - When `srcValidate` is set, its LHS/RHS must be read from the MemRef
+//    buffers `srcBuffLhs`/`srcBuffRhs` (e.g. the buffers of the contraction
+//    that anchors the rewrite).
+//  - The accumulator must be a 16x16 tile, possibly with extra unit dims.
+//  - The LHS/RHS must be 16x16 tiles of VNNI-packed elements whose VNNI dim
+//    equals `blockingFactor`, possibly with extra unit dims.
+static LogicalResult validateContractOps(OpBuilder &rewriter,
+                                         vector::ContractionOp contractOp,
+                                         unsigned int blockingFactor,
+                                         Value srcBuffLhs, Value srcBuffRhs,
+                                         bool srcValidate) {
+
+  if (srcValidate) {
+    // Get the MemRef buffer of LHS operand.
+    auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getLhs(), false);
+    if (failed(srcIndxLhs))
+      return failure();
+    auto [buffLhs, indicesLhs] = *srcIndxLhs;
+
+    // Get the MemRef buffer of RHS operand.
+    auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getRhs(), false);
+    if (failed(srcIndxRhs))
+      return failure();
+    auto [buffRhs, indicesRhs] = *srcIndxRhs;
+
+    // Fail if the operands are not read from the expected MemRef buffers.
+    if (buffLhs != srcBuffLhs)
+      return failure();
+
+    if (buffRhs != srcBuffRhs)
+      return failure();
   }
-  return {srcBuff, indices};
+
+  VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+  if (!accTy)
+    return failure();
+
+  // Returns true if `shape` has exactly two dims of size 16, exactly
+  // `numVnniDims` dims equal to `blockingFactor`, and only unit dims
+  // otherwise. Requiring two 16-sized dims rejects shapes such as <1x16>,
+  // which would otherwise be lowered to full 16x16 AMX tiles that access
+  // more memory than the contraction itself does.
+  auto hasTileShape = [&](ArrayRef<int64_t> shape, int64_t numVnniDims) {
+    int64_t num16Dims = llvm::count(shape, 16);
+    int64_t numVnniSizedDims = llvm::count(shape, int64_t(blockingFactor));
+    int64_t numUnitDims = llvm::count(shape, 1);
+    return num16Dims == 2 && numVnniSizedDims == numVnniDims &&
+           num16Dims + numVnniSizedDims + numUnitDims ==
+               static_cast<int64_t>(shape.size());
+  };
+
+  // The accumulator must be a 16x16 tile plus unit dims, e.g. <1x16x16> or
+  // <16x16>.
+  if (!hasTileShape(accTy.getShape(), /*numVnniDims=*/0))
+    return failure();
+
+  // The LHS must be a 16x16 tile of VNNI-packed elements plus unit dims, e.g.
+  // <1x16x16x2> or <16x16x4>. The vnni dim must equal the blocking factor.
+  VectorType lhsTy = contractOp.getLhsType();
+  if (!hasTileShape(lhsTy.getShape(), /*numVnniDims=*/1))
+    return failure();
+
+  // The RHS must be a 16x16 tile of VNNI-packed elements plus unit dims, e.g.
+  // <1x16x16x2> or <16x16x4>. The vnni dim must equal the blocking factor.
+  VectorType rhsTy = contractOp.getRhsType();
+  if (!hasTileShape(rhsTy.getShape(), /*numVnniDims=*/1))
+    return failure();
+
+  return success();
 }
 
+// Returns the loop index position to get mapped during the
+// MemRef type clone.
 static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
   Value iv = loop.getInductionVar();
 
   Value srcBuff;
-  SmallVector<OpFoldResult> indexVals;
   llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
-      .Case<TransferReadOp, LoadOp>([&](auto readOp) {
-        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
-                                              readOp.getIndices().end());
-        srcBuff = readOp.getOperand(0);
-      });
+      .Case<TransferReadOp, LoadOp>(
+          [&](auto readOp) { srcBuff = readOp.getOperand(0); });
 
   auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
   if (!subview)
@@ -101,103 +187,76 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
   return 0;
 }
 
+// Creates the amx.tile_load for `operand` from the collapsed MemRef `mat`.
 static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
                                        Value operand, Value mat, Type ipType,
-                                       bool rhs) {
-
-  SmallVector<OpFoldResult> indexVals;
-  llvm::TypeSwitch<Operation *>(operand.getDefiningOp())
-      .Case<TransferReadOp, LoadOp>([&](auto readOp) {
-        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
-                                              readOp.getIndices().end());
-      });
-
-  indexVals.pop_back();
-  SmallVector<Value> indices;
-  indices.reserve(indexVals.size());
+                                       bool rhs, unsigned int offset) {
 
-  for (OpFoldResult ofr : indexVals) {
-    indices.push_back(
-        mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-  }
+  auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
+  assert(succeeded(srcIndx) && "expected operand to be read from a MemRef");
+  auto [srcBuff, indices] = *srcIndx;
+  indices.pop_back();
 
   if (rhs) {
-    int offset = 4;
-    if (ipType.isBF16()) {
-      offset = 2;
-    }
-
-    auto c2 = arith::ConstantIndexOp::create(rewriter, loc, offset);
-    indices[indices.size() - 1] =
-        arith::MulIOp::create(rewriter, loc, indices[indices.size() - 1], c2);
-  }
-
-  amx::TileType tileType = amx::TileType::get({16, 64}, ipType);
-  if (ipType.isBF16()) {
-    tileType = amx::TileType::get({16, 32}, ipType);
+    auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+    indices.back() =
+        arith::MulIOp::create(rewriter, loc, indices.back(), cOffset);
   }
 
+  amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
   auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
-
   return load;
 }
 
+// Emits one AMX tiled dot-product per contraction in `ops`, sharing loads.
 static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
                                         SmallVector<vector::ContractionOp> ops,
                                         Value matA, Value matB, Type ipType,
-                                        Type opType, ValueRange accIterArgs) {
-  auto subviewType = cast<ShapedType>(matA.getType());
-  auto subviewRank = subviewType.getRank();
-  auto collapsedOpnd = collapseInnerDims(rewriter, loc, matA, subviewRank - 2);
+                                        Type opType, ValueRange accIterArgs,
+                                        unsigned int offset) {
 
-  auto subviewType1 = cast<ShapedType>(matB.getType());
-  auto subviewRank1 = subviewType1.getRank();
-  auto collapsedOpnd1 =
-      collapseInnerDims(rewriter, loc, matB, subviewRank1 - 2);
+  auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
+  auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
 
   SmallVector<Value> accumulators;
+  // Maps each vector read op (transfer_read or load) to its amx.tile_load.
   llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
 
+  // Iterate over the contraction operations and compute the tiled dot-product.
   for (size_t i = 0; i < ops.size(); i++) {
 
     Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
-
     amx::TileLoadOp tilesLhs;
-
-    auto it = readsToTileLoads.find(readOpLhs);
-    if (it != readsToTileLoads.end()) {
-      tilesLhs = it->second;
+    auto itLhs = readsToTileLoads.find(readOpLhs);
+    if (itLhs != readsToTileLoads.end()) {
+      tilesLhs = itLhs->second;
     } else {
-
-      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), collapsedOpnd,
-                                 ipType, false);
+      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
+                                 subviewCollapseLhs, ipType, false, offset);
       readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
     }
 
     Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
-
     amx::TileLoadOp tilesRhs;
-
-    auto it1 = readsToTileLoads.find(readOpRhs);
-    if (it1 != readsToTileLoads.end()) {
-      tilesRhs = it1->second;
+    auto itRhs = readsToTileLoads.find(readOpRhs);
+    if (itRhs != readsToTileLoads.end()) {
+      tilesRhs = itRhs->second;
     } else {
-
-      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), collapsedOpnd1,
-                                 ipType, true);
+      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
+                                 subviewCollapseRhs, ipType, true, offset);
       readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
     }
 
-    auto tileType1 = amx::TileType::get({16, 16}, opType);
+    auto accTileType = amx::TileType::get({16, 16}, opType);
 
     Value dp;
     if (ipType.isBF16())
-      dp = amx::TileMulFOp::create(rewriter, loc, tileType1, tilesLhs, tilesRhs,
-                                   accIterArgs[i]);
+      dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
 
     if (ipType.isSignlessInteger(8))
-      dp = amx::TileMulIOp::create(rewriter, loc, tileType1, tilesLhs, tilesRhs,
-                                   accIterArgs[i]);
+      dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
+                                   tilesRhs, accIterArgs[i]);
 
     accumulators.push_back(dp);
   }
@@ -219,6 +278,23 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
+// Lowers a vector.contract, or a sequence of vector.contracts inside
+// reduction loops, into AMX tiled dot-product operations.
+//
+// For example, for Int8 inputs with an Int32 accumulator:
+// ```
+//   vector.transfer_read %arg0[...] : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.transfer_read %arg1[...] : memref<16x32x4xi8>, vector<16x16x4xi8>
+//   vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
+//   vector.transfer_write %arg2[...] : vector<16x16xi32>, memref<32x32xi32>
+// ```
+// is rewritten to (schematically)
+// ```
+//   amx.tile_load %arg0[...] : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_load %arg1[...] : memref<16x32x4xi8> into !amx.tile<16x64xi8>
+//   amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
+//   amx.tile_store %arg2[...] : memref<32x32xi32>, !amx.tile<16x16xi32>
+// ```
 struct VectorContractToPackedTypeTiledDotProduct
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -230,6 +306,32 @@ struct VectorContractToPackedTypeTiledDotProduct
       return rewriter.notifyMatchFailure(contractOp,
                                          "Expects add combining kind.");
 
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16() &&
+        !lhsTy.getElementType().isSignlessInteger(8))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+        (lhsTy.getElementType().isSignlessInteger(8) &&
+         !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 for BF16 or Int32 for Int8 "
+                                         "accumulation type is supported.");
+
+    // The VNNI blocking factor is 2 for BF16 and 4 for Int8 inputs.
+    unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+    bool isVnni = x86vector::isInVnniLayout(contractOp.getOperation(),
+                                            contractOp.getIndexingMapsArray(),
+                                            blockingFactor);
+    if (!isVnni)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only VNNI-packed inputs are supported.");
+
     Operation *accReadOp =
         x86vector::traceToVectorReadLikeParentOperation(contractOp.getAcc());
 
@@ -237,50 +339,110 @@ struct VectorContractToPackedTypeTiledDotProduct
         x86vector::traceToVectorWriteLikeUserOperation(contractOp.getResult());
 
     if (!accReadOp || !resultWriteOp)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          contractOp, "The ACC operand of the vector.contract should be a "
+                      "transfer_read or a load. And, the result should be "
+                      "stored using transfer_write or store.");
 
-    /*if (accReadOp->getBlock() == contractOp->getBlock())
-      return failure();
+    Type ipType;
+    Type opType;
+
+    if (lhsTy.getElementType().isBF16()) {
+      ipType = rewriter.getBF16Type();
+      opType = rewriter.getF32Type();
+    }
+
+    if (lhsTy.getElementType().isSignlessInteger(8)) {
+      ipType = rewriter.getIntegerType(8);
+      opType = rewriter.getIntegerType(32);
+    }
 
-    if (resultWriteOp->getBlock() == contractOp->getBlock())
-      return failure(); */
+    if (accReadOp->getBlock() == contractOp->getBlock() &&
+        resultWriteOp->getBlock() != contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator store is in different block.");
+
+    if (accReadOp->getBlock() != contractOp->getBlock() &&
+        resultWriteOp->getBlock() == contractOp->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator read is in different block.");
 
-    // case for just one vc rewrite.
+    // Case 1: A single contraction whose accumulator read and result write are
+    // in the same block as the contraction.
     if (accReadOp->getBlock() == contractOp->getBlock() &&
         resultWriteOp->getBlock() == contractOp->getBlock()) {
-      Location loc = contractOp.getLoc();
-      auto tileType = amx::TileType::get({16, 32}, rewriter.getBF16Type());
 
-      auto [srcBuffLhs, indicesLhs] = getSrcIndxValue(
-          rewriter, contractOp.getLoc(), contractOp.getLhs(), true);
+      LogicalResult validate = validateContractOps(
+          rewriter, contractOp, blockingFactor, Value(), Value(), false);
+
+      if (failed(validate))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The contract operation doesn't satisfy the operands "
+                        "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
+                        "The rest dims should be 1.");
+
+      Location loc = contractOp.getLoc();
+      // Resolve ACC first; unlike LHS/RHS it does not create a collapse_shape.
+      auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getAcc(), false);
+      if (failed(srcIndxAcc))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The ACC src is not a MemRef type.");
+      auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
+
+      auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getLhs(), true);
+      if (failed(srcIndxLhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The LHS src is not a MemRef type.");
+      auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
+
+      auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                        contractOp.getRhs(), true);
+      if (failed(srcIndxRhs))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "The RHS src is not a MemRef type.");
+      auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+      // amx.tile_loads
+      auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
       auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
                                              srcBuffLhs, indicesLhs);
-
-      auto [srcBuffRhs, indicesRhs] = getSrcIndxValue(
-          rewriter, contractOp.getLoc(), contractOp.getRhs(), true);
       auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
                                              srcBuffRhs, indicesRhs);
 
-      auto [srcBuffAcc, indicesAcc] = getSrcIndxValue(
-          rewriter, contractOp.getLoc(), contractOp.getAcc(), false);
-      auto tileTypeAcc = amx::TileType::get({16, 16}, rewriter.getF32Type());
+      auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
       auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
                                              srcBuffAcc, indicesAcc);
 
-      auto tdp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
-                                         loadRhs, loadAcc);
-      amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, tdp);
+      // Tiled dot-product.
+      Value dp;
+      if (ipType.isBF16())
+        dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+
+      if (ipType.isSignlessInteger(8))
+        dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
+                                     loadRhs, loadAcc);
+
+      amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
 
       rewriter.eraseOp(resultWriteOp);
       return success();
     }
 
-    SmallVector<scf::ForOp> list;
+    // Case 2: The accumulators are carried as iter args through the reduction
+    // loops. Reduction loop depths of up to 2 are supported.
+    // TODO: Support reduction loops of arbitrary depth.
+    SmallVector<scf::ForOp> loopLists;
     Operation *current = contractOp;
 
     while (true) {
       Operation *parent = current->getParentOfType<scf::ForOp>();
-      list.push_back(dyn_cast<scf::ForOp>(parent));
+      if (!parent)
+        return rewriter.notifyMatchFailure(
+            contractOp, "No enclosing scf.for is in the acc read's block.");
+      loopLists.push_back(dyn_cast<scf::ForOp>(parent));
 
       if (accReadOp->getBlock() == parent->getBlock()) {
         break;
@@ -289,43 +451,63 @@ struct VectorContractToPackedTypeTiledDotProduct
       current = parent;
     }
 
-    if (list.size() > 2 || list.size() == 0)
-      return failure();
+    if (loopLists.size() > 2 || loopLists.size() == 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Rewrite is supported until reduction loop depth of 2.");
 
-    SmallVector<vector::ContractionOp> ops;
+    auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getLhs(), false);
+    if (failed(srcIndxLhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The LHS src is not a MemRef type.");
+    auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
 
-    for (mlir::Operation &op : list[0].getBody()->getOperations()) {
+    auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
+                                      contractOp.getRhs(), false);
+    if (failed(srcIndxRhs))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The RHS src is not a MemRef type.");
+    auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+
+    // The ops defining the LHS/RHS MemRef sources (the bases of the reads) are
+    // cloned into the new loops below with the loop induction variables
+    // remapped, so the sources must be defined by an op (e.g. a subview) and
+    // cannot be block arguments.
+    Operation *vectorOpLhs = srcBuffLhs.getDefiningOp();
+    Operation *vectorOpRhs = srcBuffRhs.getDefiningOp();
+    if (!vectorOpLhs || !vectorOpRhs)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The LHS/RHS MemRef source must be defined by an op.");
+
+    // Collect all contractions in the innermost loop body. Each of them must
+    // satisfy the shape constraints of this lowering and read from the same
+    // LHS/RHS MemRef sources as `contractOp`.
+    SmallVector<vector::ContractionOp> ops;
+    for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
 
       if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
-        ops.push_back(contract);
-      }
-    }
 
-    Type ipType;
-    Type opType;
-    VectorType lhsTy = contractOp.getLhsType();
-    if (lhsTy.getElementType().isBF16()) {
-      ipType = rewriter.getBF16Type();
-      opType = rewriter.getF32Type();
-    }
+        LogicalResult validate = validateContractOps(
+            rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
 
-    if (lhsTy.getElementType().isSignlessInteger(8)) {
-      ipType = rewriter.getIntegerType(8);
-      opType = rewriter.getIntegerType(32);
+        if (failed(validate))
+          return rewriter.notifyMatchFailure(
+              contractOp, "The associated contract operations doesn't satisfy "
+                          "the re-write conditions either the dimensions are "
+                          "wrong or MemRef source are different.");
+
+        ops.push_back(contract);
+      }
     }
 
     scf::ForOp outerLoop;
     scf::ForOp innerLoop;
 
-    auto vectorReadOpLhs =
-        contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
-    auto vectorReadOpRhs =
-        contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
-
     scf::ForOp newLoop;
-    if (list.size() == 2) {
-      outerLoop = list[1];
-      innerLoop = list[0];
+    // Case 2a: Reduction loop depth is 2.
+    if (loopLists.size() == 2) {
+      outerLoop = loopLists[1];
+      innerLoop = loopLists[0];
 
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
@@ -343,32 +525,32 @@ struct VectorContractToPackedTypeTiledDotProduct
                     Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
                   IRMapping mapping;
                   mapping.map(
-                      vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+                      vectorOpLhs->getOperand(
                           getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
                       ivOuterLoop);
                   mapping.map(
-                      vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+                      vectorOpLhs->getOperand(
                           getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
                       ivNewInnerLoop);
-                  auto lhsClone = rewriterNewInnerLoop.clone(
-                      *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+                  auto lhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
 
                   IRMapping rhsMapping;
                   rhsMapping.map(
-                      vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+                      vectorOpRhs->getOperand(
                           getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
                       ivOuterLoop);
                   rhsMapping.map(
-                      vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+                      vectorOpRhs->getOperand(
                           getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
                       ivNewInnerLoop);
-                  auto rhsClone = rewriterNewInnerLoop.clone(
-                      *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+                  auto rhsClone =
+                      rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
 
                   SmallVector<Value> accumulators = createTiledDp(
                       rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
                       rhsClone->getResult(0), ipType, opType,
-                      iterArgsNewInnerLoop);
+                      iterArgsNewInnerLoop, blockingFactor);
 
                   scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
                                        accumulators);
@@ -379,8 +561,9 @@ struct VectorContractToPackedTypeTiledDotProduct
           });
     }
 
-    if (list.size() == 1) {
-      outerLoop = list[0];
+    // Case 2b: Reduction loop depth is 1.
+    if (loopLists.size() == 1) {
+      outerLoop = loopLists[0];
 
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
@@ -391,37 +574,37 @@ struct VectorContractToPackedTypeTiledDotProduct
               Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
             IRMapping mapping;
             mapping.map(
-                vectorReadOpLhs.getBase().getDefiningOp()->getOperand(
+                vectorOpLhs->getOperand(
                     getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
                 ivOuterLoop);
 
-            auto lhsClone = rewriterOuterLoop.clone(
-                *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+            auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
 
             IRMapping rhsMapping;
             rhsMapping.map(
-                vectorReadOpRhs.getBase().getDefiningOp()->getOperand(
+                vectorOpRhs->getOperand(
                     getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
                 ivOuterLoop);
 
-            auto rhsClone = rewriterOuterLoop.clone(
-                *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+            auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
 
             SmallVector<Value> accumulators = createTiledDp(
                 rewriter, locOuterLoop, ops, lhsClone->getResult(0),
-                rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop);
+                rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
+                blockingFactor);
 
             scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
           });
     }
 
-    // post processing after the loop
+    // Post-processing after the loops: for each contraction, copy its AMX
+    // accumulator tile into a 16x16 stack buffer (memref.alloca), add the
+    // initial accumulator value row by row, and store the sum to the C-matrix.
     auto bufferType = MemRefType::get({16, 16}, opType);
     auto bBuffer =
         memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
 
     SmallVector<Value> dps = newLoop.getResults();
-
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
       Operation *resultWriteOp =
@@ -443,45 +626,47 @@ struct VectorContractToPackedTypeTiledDotProduct
       scf::ForOp::create(
           rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
           [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
-            auto row1 = vector::LoadOp::create(rewriter, loc,
-                                               VectorType::get(16, opType),
-                                               bBuffer, ValueRange{iv, c0});
+            auto resultAcc = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), bBuffer,
+                ValueRange{iv, c0});
 
-            Operation *readOp1 =
+            Operation *accReadOp =
                 x86vector::traceToVectorReadLikeParentOperation(
                     ops[i].getAcc());
 
-            Value srcBuff;
-            SmallVector<OpFoldResult> indexVals;
-            llvm::TypeSwitch<Operation *>(readOp1).Case<TransferReadOp, LoadOp>(
-                [&](auto readOp) {
-                  indexVals = SmallVector<OpFoldResult>(
-                      readOp.getIndices().begin(), readOp.getIndices().end());
-                  srcBuff = readOp.getOperand(0);
-                });
+            Value srcBuffAcc;
+            SmallVector<Value> indicesAcc;
 
-            SmallVector<Value> indices;
-            indices.reserve(indexVals.size());
+            llvm::TypeSwitch<Operation *>(accReadOp)
+                .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+                  srcBuffAcc = readOp.getOperand(0);
 
-            for (OpFoldResult ofr : indexVals) {
-              indices.push_back(
-                  mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-            }
+                  auto indices = readOp.getIndices();
+                  indicesAcc.reserve(indices.size());
 
-            Value sum = arith::AddIOp::create(builder, loc, iv, indices[0]);
-            indices[0] = sum;
+                  llvm::transform(
+                      indices, std::back_inserter(indicesAcc),
+                      [&](OpFoldResult ofr) {
+                        return mlir::getValueOrCreateConstantIndexOp(rewriter,
+                                                                     loc, ofr);
+                      });
+                });
 
-            auto row2 = vector::LoadOp::create(
-                rewriter, loc, VectorType::get(16, opType), srcBuff, indices);
+            Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
+            indicesAcc[0] = sum;
 
+            auto acc = vector::LoadOp::create(rewriter, loc,
+                                              VectorType::get(16, opType),
+                                              srcBuffAcc, indicesAcc);
             Value addition;
             if (ipType.isBF16())
-              addition = arith::AddFOp::create(rewriter, loc, row1, row2);
+              addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
 
             if (ipType.isSignlessInteger(8))
-              addition = arith::AddIOp::create(rewriter, loc, row1, row2);
+              addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
 
-            vector::StoreOp::create(builder, loc, addition, srcBuff, indices);
+            vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+                                    indicesAcc);
 
             scf::YieldOp::create(builder, outerLoop.getLoc());
           });
diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
index 081eb56a029f6..ccc1a0f0c9bdc 100644
--- a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -1,57 +1,311 @@
 // RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
 
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_int8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  // Read A (batch x M x K x 4) and B (batch x K x N x 4) int8 operands.
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+  // Read the 16x16 i32 accumulator.
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+  // d0, d4 and d3 are reductions; d1/d2 index the result. Expected to lower to amx.tile_muli.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_int8
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xi32>
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x16x4xi8>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_int8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  // Read A (M x K x 4) and B (K x N x 4) int8 operands; no batch dimension.
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+  // Read the 16x16 i32 accumulator.
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+  // d4 and d3 are reductions; d1/d2 index the result. Expected to lower to amx.tile_muli.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_int8
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xi32>
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  // Read A (batch x M x K x 2) and B (batch x K x N x 2) bf16 operands.
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+  // Read the 16x16 f32 accumulator.
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+  // d0, d4 and d3 are reductions; d1/d2 index the result. Expected to lower to amx.tile_mulf.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_bf16
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xf32>
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x2xbf16>
+!vecB = vector<1x16x16x2xbf16>
+!vecC = vector<1x16x16xf32>
+!memrefA = memref<1x32x16x2xbf16>
+!memrefB = memref<1x16x32x2xbf16>
+!memrefC = memref<1x32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  // Read A (batch x M x K x 2) and B (batch x K x N x 2) bf16 operands.
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+  // Read the 1x16x16 f32 accumulator; the batch dim d0 is kept in the result.
+  %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+  // d0 is parallel (batch matmul); d4 and d3 are reductions. Expected to lower to amx.tile_mulf.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x32xbf16>
+// CHECK: amx.tile_load {{.*}} !amx.tile<16x16xf32>
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store {{.*}} !amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
 
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
 
-  module {
-    func.func @brgemm_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
-      %0 = ub.poison : f32
-      %1 = ub.poison : bf16
-      %c0 = arith.constant 0 : index
-      %c64 = arith.constant 64 : index
-      %c128 = arith.constant 128 : index
-      %c16 = arith.constant 16 : index
-      %c32 = arith.constant 32 : index
-      %c1 = arith.constant 1 : index
-      scf.for %arg3 = %c0 to %c64 step %c32 {
-        scf.for %arg4 = %c0 to %c128 step %c32 {
-          %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
-          %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
-            %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
-              %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
-              %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
-              %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
-              %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
-              %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
-              %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<16x16xf32>
-              scf.yield %10, %12, %14, %15 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
-            }
-            scf.yield %7#0, %7#1, %7#2, %7#3 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
-          }
-          vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
+func.func @brgemm_bf16_loop(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32>  {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %7:4 = scf.for %arg10 = %c0 to %c64 step %c16 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10, 0] [1, 32, 16, 2] [1, 1, 1, 1] :
+                memref<16x64x64x2xbf16> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] :
+                memref<16x64x128x2xbf16> to !memrefB
+          %8 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+
+          %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %9, %arg11 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+          %11 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %11, %arg12 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %13 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %13, %9, %arg13 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %13, %11, %arg14 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+          scf.yield %10, %12, %14, %15 : !vecC, !vecC, !vecC, !vecC
         }
+        scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
       }
-      %alloc = memref.alloc() : memref<64x128xf32>
-      memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
-      return %alloc : memref<64x128xf32>
+
+      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+        !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+        !vecC, !memrefC
     }
   }
 
-// CHECK-LABEL: @brgemm_amx
-// CHECK: amx.tile_mulf
+  return %arg2 : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @brgemm_bf16_loop
+// CHECK: amx.tile_zero : !amx.tile<16x16xf32>
+// CHECK-COUNT-2: scf.for {{.*}} -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
+// CHECK: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
 // CHECK-NOT: vector.contract
 
 module attributes {transform.with_named_sequence} {
@@ -66,55 +320,75 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
 #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-  module {
-    func.func @batch_amx(%arg0: memref<64x64x2xbf16>, %arg1: memref<64x128x2xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
-      %0 = ub.poison : f32
-      %1 = ub.poison : bf16
-      %c0 = arith.constant 0 : index
-      %c64 = arith.constant 64 : index
-      %c128 = arith.constant 128 : index
-      %c32 = arith.constant 32 : index
-      %c16 = arith.constant 16 : index
-      scf.for %arg3 = %c0 to %c64 step %c32 {
-        scf.for %arg4 = %c0 to %c128 step %c32 {
-          %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
-          %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32, strided<[128, 1], offset: ?>>, vector<16x16xf32>
-          %6:4 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>) {
-            %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [32, 16, 2] [1, 1, 1] : memref<64x64x2xbf16> to memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>
-            %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 2] [1, 1, 1] : memref<64x128x2xbf16> to memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>
-            %7 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
-            %8 = vector.transfer_read %subview_0[%c16, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<32x16x2xbf16, strided<[128, 2, 1], offset: ?>>, vector<16x16x2xbf16>
-            %9 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
-            %10 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} : memref<16x32x2xbf16, strided<[256, 2, 1], offset: ?>>, vector<16x16x2xbf16>
-            %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg6 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
-            %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg7 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
-            %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg8 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
-            %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg9 {unroll_shape = array<i64: 2, 16, 16, 16>} : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>
-            scf.yield %11, %12, %13, %14 : vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
-          }
-          vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-          vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32, strided<[128, 1], offset: ?>>
-        }
+func.func @matmul_int8_loop(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+        %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+                memref<64x128x4xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+        %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+        scf.yield %8, %9 : !vecC, !vecC
       }
-      %alloc = memref.alloc() : memref<64x128xf32>
-      memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
-      return %alloc : memref<64x128xf32>
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
     }
   }
 
-// CHECK-LABEL: @batch_amx
-// CHECK: amx.tile_mulf
-// CHECK-NOT: vector.contract
+  return
+}
 
+// CHECK-LABEL: @matmul_int8_loop
+// CHECK: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -128,56 +402,73 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecAB = vector<1x16x16x4xi8>
+!vecC = vector<1x16x16xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
 
-  module {
-    func.func @matmul_amx(%arg0: memref<16x64x64x2xbf16>, %arg1: memref<16x64x128x2xbf16>, %arg2: memref<16x64x128xf32>) -> memref<16x64x128xf32> attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
-      %0 = ub.poison : f32
-      %1 = ub.poison : bf16
-      %c0 = arith.constant 0 : index
-      %c64 = arith.constant 64 : index
-      %c128 = arith.constant 128 : index
-      %c16 = arith.constant 16 : index
-      %c32 = arith.constant 32 : index
-      %c1 = arith.constant 1 : index
-      scf.for %arg3 = %c0 to %c64 step %c32 {
-        scf.for %arg4 = %c0 to %c128 step %c32 {
-          scf.for %arg5 = %c0 to %c16 step %c1 {
-            %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 32, 32] [1, 1, 1] : memref<16x64x128xf32> to memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
-            %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
-            %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
-            %4 = vector.transfer_read %subview[%c0, %c16, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
-            %5 = vector.transfer_read %subview[%c0, %c16, %c16], %0 {in_bounds = [true, true, true]} : memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>, vector<1x16x16xf32>
-            %6:4 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3, %arg9 = %4, %arg10 = %5) -> (vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>) {
-              %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x64x2xbf16> to memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
-              %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x64x128x2xbf16> to memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>
-              %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %8 = vector.transfer_read %subview_0[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %10 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[16384, 256, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
-              %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %9, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
-              %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %10, %arg8 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
-              %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %arg9 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
-              %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %10, %arg10 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : vector<1x16x16x2xbf16>, vector<1x16x16x2xbf16> into vector<1x16x16xf32>
-              scf.yield %11, %12, %13, %14 : vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>, vector<1x16x16xf32>
-            }
-            vector.transfer_write %6#3, %subview[%c0, %c16, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
-            vector.transfer_write %6#2, %subview[%c0, %c16, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
-            vector.transfer_write %6#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
-            vector.transfer_write %6#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, memref<1x32x32xf32, strided<[8192, 128, 1], offset: ?>>
-          }
+func.func @batch_matmul_int8_loop(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      scf.for %arg5 = %c0 to %c16 step %c1 {
+
+        %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+                memref<16x64x128xi32> to !memrefC
+        %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %3 = vector.transfer_read %subview[%c0, %c0, %c16], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %4:2 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2, %arg8 = %3) -> (!vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+                memref<16x64x64x4xi8> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+                memref<16x64x128x4xi8> to !memrefB
+          %5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecAB
+          %6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %7 = vector.transfer_read %subview_1[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecAB
+          %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg7 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg8 {unroll_shape = array<i64: 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+          scf.yield %8, %9 : !vecC, !vecC
         }
+
+        vector.transfer_write %4#1, %subview[%c0, %c0, %c16] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
+        vector.transfer_write %4#0, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
       }
-      %alloc = memref.alloc() : memref<16x64x128xf32>
-      memref.copy %arg2, %alloc : memref<16x64x128xf32> to memref<16x64x128xf32>
-      return %alloc : memref<16x64x128xf32>
     }
   }
+  return
+}
 
-// CHECK-LABEL: @matmul_amx
-// CHECK: amx.tile_mulf
+// CHECK-LABEL: @batch_matmul_int8_loop
+// CHECK-COUNT-2: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK-COUNT-3: amx.tile_load
+// CHECK-COUNT-2: amx.tile_muli
+// CHECK: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>
 // CHECK-NOT: vector.contract
 
 module attributes {transform.with_named_sequence} {
@@ -192,16 +483,68 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-!vecA = vector<1x16x16x2xbf16>
-!vecB = vector<1x16x16x2xbf16>
-!vecC = vector<16x16xf32>
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x32x16x4xi8>
+!memrefB = memref<1x16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x32x16x2xbf16>
+!vecB = vector<1x16x32x2xbf16>
+!vecC = vector<1x32x32xf32>
 !memrefA = memref<1x32x16x2xbf16>
 !memrefB = memref<1x16x32x2xbf16>
-!memrefC = memref<32x32xf32>
+!memrefC = memref<1x32x32xf32>
 #map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
 #map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
-#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
-func.func @amx(
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_dimensions(
   %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
 {
   %c0 = arith.constant 0 : index
@@ -213,11 +556,113 @@ func.func @amx(
   %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
         !memrefB, !vecB
 
+  %3 = vector.transfer_read %arg2[%c0, %c0, %c0], %32 {in_bounds = [true, true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_wrong_dimensions
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_mulf
+// CHECK-NOT: amx.tile_store
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x16x4xi8>
+!vecB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefB = memref<16x32x4xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_memref_source_LHS(
+  %arg0: !vecA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+
   %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
 
   %4 = vector.contract {
     indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_no_memref_source_LHS
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @negative_no_vnni_packed(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>}
     %1, %2, %3 : !vecA, !vecB into !vecC
 
@@ -226,9 +671,12 @@ func.func @amx(
   return %arg2 : !memrefC
 }
 
-// CHECK-LABEL: @amx
-// CHECK: amx.tile_mulf
-// CHECK-NOT: vector.contract
+// CHECK-LABEL: @negative_no_vnni_packed
+// CHECK-NOT: amx.tile_load {{.*}} !amx.tile<16x64xi8>
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: amx.tile_store {{.*}} !amx.tile<16x16xi32>
+// CHECK: vector.contract
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -240,3 +688,156 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+!vecAB = vector<16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x16x4xi8, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xi8, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @negative_VCs_LHS_src_differ(%arg0: memref<64x64x4xi8>, %arg1: memref<64x128x4xi8>, %arg2: memref<64x128xi32>, %arg13: memref<64x64x4xi8>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+        %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_negative = memref.subview %arg13[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+                memref<64x128x4xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %odd_load = vector.transfer_read %subview_negative[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+        %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %odd_load, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+        scf.yield %8, %9 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: @negative_VCs_LHS_src_differ
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: scf.for {{.*}} -> (!amx.tile<16x16xi32>, !amx.tile<16x16xi32>) {
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK-NOT: scf.yield {{.*}} !amx.tile<16x16xi32>, !amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (vector<16x16xi32>, vector<16x16xi32>) {
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x16x4xi8>
+!vecB = vector<1x16x32x4xi8>
+!vecC = vector<1x16x32xi32>
+!memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x16x32x4xi8, strided<[32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<1x16x32xi32, strided<[8192, 128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>
+
+func.func @negative_wrong_N_dim(%arg0: memref<16x64x64x4xi8>, %arg1: memref<16x64x128x4xi8>, %arg2: memref<16x64x128xi32>) {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      scf.for %arg5 = %c0 to %c16 step %c1 {
+
+        %subview = memref.subview %arg2[%arg5, %arg3, %arg4] [1, 16, 32] [1, 1, 1] :
+                memref<16x64x128xi32> to !memrefC
+        %2 = vector.transfer_read %subview[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+                !memrefC, !vecC
+        %3 = scf.for %arg6 = %c0 to %c64 step %c16 iter_args(%arg7 = %2) -> (!vecC) {
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6, 0] [1, 16, 16, 4] [1, 1, 1, 1] :
+                memref<16x64x64x4xi8> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4, 0] [1, 16, 32, 4] [1, 1, 1, 1] :
+                memref<16x64x128x4xi8> to !memrefB
+          %4 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefA, !vecA
+          %5 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} :
+                !memrefB, !vecB
+          %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %4, %5, %arg7 {unroll_shape = array<i64: 1, 4, 16, 32, 16>} : !vecA, !vecB into !vecC
+          scf.yield %6 : !vecC
+        }
+        vector.transfer_write %3, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+                !vecC, !memrefC
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @negative_wrong_N_dim
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 58f0eda53f2fe0e6d705798711ae20e2ffed5625 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 3 Mar 2026 00:51:26 -0800
Subject: [PATCH 5/7] added -ve test-cases

---
 .../AMX/vector-contract-to-tiled-dp.mlir      | 84 +++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
index ccc1a0f0c9bdc..e95f739ea0197 100644
--- a/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/AMX/vector-contract-to-tiled-dp.mlir
@@ -841,3 +841,87 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vecAB = vector<1x1x16x16x4xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>
+!memrefB = memref<1x1x16x32x4xi8, strided<[524288, 32768, 512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+
+func.func @negative_reduction_loop_depth_3(%arg0: memref<2x16x64x64x4xi8>, %arg1: memref<2x16x64x128x4xi8>, %arg2: memref<64x128xi32>) attributes {dlti.target_system_spec = #dlti.target_system_spec<"CPU" = #dlti.target_device_spec<"reg_gemm_unroll" = [16, 16, 16]>>} {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c2 = arith.constant 2 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c2 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %5:2 = scf.for %arg8 = %c0 to %c16 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+          %6:2 = scf.for %arg11 = %c0 to %c64 step %c16 iter_args(%arg12 = %arg9, %arg13 = %arg10) -> (!vecC, !vecC) {
+            %subview_0 = memref.subview %arg0[%arg5, %arg8, %arg3, %arg11, 0] [1, 1, 16, 16, 4] [1, 1, 1, 1, 1] :
+                memref<2x16x64x64x4xi8> to !memrefA
+            %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg11, %arg4, 0] [1, 1, 16, 32, 4] [1, 1, 1, 1, 1] :
+                memref<2x16x64x128x4xi8> to !memrefB
+            %7 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefA, !vecAB
+            %8 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefB, !vecAB
+            %9 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true, true]} :
+                !memrefB, !vecAB
+
+            %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %8, %arg12 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+            %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %9, %arg13 {unroll_shape = array<i64: 1, 1, 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+            scf.yield %10, %11 : !vecC, !vecC
+          }
+          scf.yield %6#0, %6#1 : !vecC, !vecC
+        }
+        scf.yield %5#0, %5#1 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+  return
+}
+
+
+// CHECK-LABEL: @negative_reduction_loop_depth_3
+// CHECK-NOT: amx.tile_zero : !amx.tile<16x16xi32>
+// CHECK-NOT: amx.tile_load
+// CHECK-NOT: amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 0b8d1298999184b2b403d57e4317dca17b40a7c5 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 3 Mar 2026 01:46:34 -0800
Subject: [PATCH 6/7] remove the new AMX related files

---
 .../AMX/TransformOps/AMXTransformOps.h        | 31 ----------
 .../AMX/TransformOps/AMXTransformOps.td       | 32 -----------
 .../Dialect/AMX/TransformOps/CMakeLists.txt   |  4 --
 .../AMX/TransformOps/AMXTransformOps.cpp      | 57 -------------------
 .../Dialect/AMX/TransformOps/CMakeLists.txt   | 17 ------
 5 files changed, 141 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
 delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
 delete mode 100644 mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
 delete mode 100644 mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
 delete mode 100644 mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt

diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
deleted file mode 100644
index 8806635df8eb5..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//===- AMXTransformOps.h - AMX transform ops --------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
-#define MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
-
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/IR/OpImplementation.h"
-
-//===----------------------------------------------------------------------===//
-// AMX Transform Operations
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h.inc"
-
-namespace mlir {
-class DialectRegistry;
-
-namespace amx {
-void registerTransformDialectExtension(DialectRegistry &registry);
-
-} // namespace amx
-} // namespace mlir
-
-#endif // MLIR_DIALECT_AMX_TRANSFORMOPS_AMXTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td b/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
deleted file mode 100644
index 74bcba3c37e57..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/AMXTransformOps.td
+++ /dev/null
@@ -1,32 +0,0 @@
-//===- AMXTransformOps.td - AMX transform ops --*- tablegen -*-------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX_TRANSFORM_OPS
-#define AMX_TRANSFORM_OPS
-
-include "mlir/Dialect/Transform/IR/TransformDialect.td"
-include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/OpBase.td"
-include "mlir/Dialect/Transform/IR/TransformAttrs.td"
-include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/IR/RegionKindInterface.td"
-
-def ApplyVectorContractToPackedTypeTiledDotProductPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.amx.vector_contract_to_packed_type_tiled_dot_product",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Collect patterns to lower a BF16/Int8 type vector.contract operation 
-	to a BF16/Int8 tiled dot-product.
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
-
-#endif // AMX_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
deleted file mode 100644
index 41255b936be71..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/TransformOps/CMakeLists.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS AMXTransformOps.td)
-mlir_tablegen(AMXTransformOps.h.inc -gen-op-decls)
-mlir_tablegen(AMXTransformOps.cpp.inc -gen-op-defs)
-add_mlir_dialect_tablegen_target(MLIRAMXTransformOpsIncGen)
diff --git a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp b/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
deleted file mode 100644
index 6ff573407b42b..0000000000000
--- a/mlir/lib/Dialect/AMX/TransformOps/AMXTransformOps.cpp
+++ /dev/null
@@ -1,57 +0,0 @@
-//===- AMXTransformOps.cpp ------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/Dialect/AMX/Transforms.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/RegionKindInterface.h"
-
-using namespace mlir;
-using namespace mlir::amx;
-using namespace mlir::transform;
-
-void mlir::transform::ApplyVectorContractToPackedTypeTiledDotProductPatternsOp::
-    populatePatterns(RewritePatternSet &patterns) {
-  amx::populateVectorContractToPackedTypeTiledDotProductPatterns(patterns);
-}
-
-//===----------------------------------------------------------------------===//
-// Transform op registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-class AMXTransformDialectExtension
-    : public transform::TransformDialectExtension<
-          AMXTransformDialectExtension> {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AMXTransformDialectExtension)
-
-  AMXTransformDialectExtension() {
-    declareGeneratedDialect<amx::AMXDialect>();
-    declareGeneratedDialect<LLVM::LLVMDialect>();
-    registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
-        >();
-  }
-};
-} // namespace
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/AMX/TransformOps/AMXTransformOps.cpp.inc"
-
-void mlir::amx::registerTransformDialectExtension(DialectRegistry &registry) {
-  registry.addExtensions<AMXTransformDialectExtension>();
-}
diff --git a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
deleted file mode 100644
index 30b4304586ab7..0000000000000
--- a/mlir/lib/Dialect/AMX/TransformOps/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_dialect_library(MLIRAMXTransformOps
-  AMXTransformOps.cpp
-
-  DEPENDS
-  MLIRAMXTransformOpsIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMCommonConversion
-  MLIRLLVMDialect
-  MLIRVectorDialect
-  MLIRSideEffectInterfaces
-  MLIRTransformDialect
-  MLIRTransformDialectUtils
-  MLIRAMXDialect
-  MLIRAMXTransforms
-  )

>From 06fc39f126a511fe620bfc61c7decb6d34f17033 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 4 Mar 2026 22:21:35 -0800
Subject: [PATCH 7/7] Remove a couple of leftover debug print statements.

---
 .../Transforms/VectorContractToPackedTypeTiledDotProduct.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
index 4bdb5ff83bb74..d5c17528f1564 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeTiledDotProduct.cpp
@@ -382,15 +382,12 @@ struct VectorContractToPackedTypeTiledDotProduct
                         "The rest dims should be 1.");
 
       Location loc = contractOp.getLoc();
-      llvm::outs() << "Reaching-here1" << "\n";
+
       auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
                                         contractOp.getLhs(), true);
-      llvm::outs() << "Reaching-here2" << "\n";
       if (failed(srcIndxLhs))
         return rewriter.notifyMatchFailure(contractOp,
                                            "The LHS src is not a MemRef type.");
-      llvm::outs() << "Reaching-here3" << "\n";
-
       auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
 
       auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),



More information about the Mlir-commits mailing list