[Mlir-commits] [mlir] [mlir][x86] Lower packed type vector.contract to AMX dot-product (online-packing) (PR #188192)

Arun Thangamani llvmlistbot at llvm.org
Wed Apr 8 11:09:30 PDT 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/188192

>From 7b888d75131d67a4e3f2cb2ffdec7e1f82abdf01 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 24 Mar 2026 01:05:51 -0700
Subject: [PATCH 1/5] support for amx online-packing

---
 .../VectorContractToAMXDotProduct.cpp         | 936 +++++++++++++++---
 .../X86/AMX/vector-contract-to-tiled-dp.mlir  | 349 ++++++-
 2 files changed, 1118 insertions(+), 167 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 85966a85af40e..2b159e6f59cb9 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1,4 +1,4 @@
-//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
+//===- VectorContractToAMXDotProduct.cpp ----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -70,8 +70,9 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
   if (!srcBuff)
     return failure();
 
-  if (isNotAcc)
+  if (isNotAcc) {
     indexVals.pop_back();
+  }
 
   SmallVector<Value> indices;
   indices.reserve(indexVals.size());
@@ -189,20 +190,159 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
 // Creates amx.tile_loads.
+// For VNNI inputs the last source index is dropped and, for the RHS, the new
+// last index is scaled by `offset`; otherwise the source indices are used
+// unchanged.
 static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
                                        Value operand, Value mat, Type ipType,
-                                       bool rhs, unsigned int offset) {
+                                       bool rhs, unsigned int offset,
+                                       bool isVnni) {
 
   auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
   auto [srcBuff, indices] = *srcIndx;
-  indices.pop_back();
+  if (isVnni)
+    indices.pop_back();
 
-  if (rhs) {
+  if (rhs && isVnni) {
     auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
     indices[indices.size() - 1] = arith::MulIOp::create(
         rewriter, loc, indices[indices.size() - 1], cOffset);
   }
 
   amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
   return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+}
+
+// Emits an scf.for that repacks one block of the RHS `matB` into
+// `bBuffer[allocStore]`. For each `iv` in [0, 16 * offset) with step `offset`
+// it loads two `16 * offset`-element vectors from `matB` at second-innermost
+// indices `iv` and `iv + offset / 2`, interleaves them with two
+// vector.shuffle ops, and stores the results at indices `iv / offset` and
+// `16 + iv / offset` of `bBuffer[allocStore]`.
+//
+// Only `offset` values 2 and 4 are supported; they are the only ones with
+// shuffle masks.
+static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
+                           Type ipType, unsigned int offset, Value bBuffer,
+                           Value allocStore) {
+  assert((offset == 2 || offset == 4) &&
+         "performShuffle supports only offsets 2 and 4");
+
+  // vector.load needs one index per dimension of `matB`, so size the index
+  // list from the memref rank. (SubViewOp::getOffsets() returns only the
+  // dynamic offsets and undercounts whenever some offsets are static.)
+  auto matBType = llvm::cast<MemRefType>(matB.getType());
+  assert(matBType.getRank() >= 2 && "expected a memref of rank >= 2");
+
+  // Interleaving masks for offset == 2 (two 32-element inputs).
+  static constexpr int64_t kMask2Lo[] = {
+      0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
+      16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
+  static constexpr int64_t kMask2Hi[] = {
+      4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
+      20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
+  // Interleaving masks for offset == 4 (two 64-element inputs).
+  static constexpr int64_t kMask4Lo[] = {
+      0,  32, 64, 96,  1,  33, 65, 97,  2,  34, 66, 98,  3,  35, 67, 99,
+      4,  36, 68, 100, 5,  37, 69, 101, 6,  38, 70, 102, 7,  39, 71, 103,
+      8,  40, 72, 104, 9,  41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
+      12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111};
+  static constexpr int64_t kMask4Hi[] = {
+      16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
+      20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
+      24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123,
+      28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
+  ArrayRef<int64_t> loMask = (offset == 2) ? ArrayRef<int64_t>(kMask2Lo)
+                                           : ArrayRef<int64_t>(kMask4Lo);
+  ArrayRef<int64_t> hiMask = (offset == 2) ? ArrayRef<int64_t>(kMask2Hi)
+                                           : ArrayRef<int64_t>(kMask4Hi);
+
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+  Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  Value cBound = arith::ConstantIndexOp::create(rewriter, loc, 16 * offset);
+  Value cRowDelta = arith::ConstantIndexOp::create(rewriter, loc, offset / 2);
+
+  SmallVector<Value> indices(matBType.getRank(), c0);
+  // Both loads and both shuffles produce this type.
+  auto vecTy = VectorType::get({static_cast<int64_t>(16 * offset)}, ipType);
+
+  scf::ForOp::create(
+      rewriter, loc, c0, cBound, cStep, ValueRange{},
+      [&](OpBuilder &b, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) {
+        Value nextRow = arith::AddIOp::create(b, bodyLoc, cRowDelta, iv);
+
+        indices[indices.size() - 2] = iv;
+        Value first = vector::LoadOp::create(b, bodyLoc, vecTy, matB,
+                                             ValueRange(indices));
+        indices[indices.size() - 2] = nextRow;
+        Value second = vector::LoadOp::create(b, bodyLoc, vecTy, matB,
+                                              ValueRange(indices));
+
+        Value packedLo = vector::ShuffleOp::create(b, bodyLoc, vecTy, first,
+                                                   second, loMask);
+        Value packedHi = vector::ShuffleOp::create(b, bodyLoc, vecTy, first,
+                                                   second, hiMask);
+
+        Value dstRow = arith::DivUIOp::create(b, bodyLoc, iv, cStep);
+        Value dstRowHi = arith::AddIOp::create(b, bodyLoc, c16, dstRow);
+        vector::StoreOp::create(b, bodyLoc, packedLo, bBuffer,
+                                ValueRange{allocStore, dstRow, c0});
+        vector::StoreOp::create(b, bodyLoc, packedHi, bBuffer,
+                                ValueRange{allocStore, dstRowHi, c0});
+        scf::YieldOp::create(b, bodyLoc);
+      });
+}
+
+// Emits the RHS tile loads for each pair of contractions in `ops` accepted by
+// validatePairVectorContract. For the first pair that reaches a given RHS read,
+// the RHS block is repacked into `bBuffer` (only when `pack` is set) and two
+// 16 x (16 * offset) tiles are loaded from `bBuffer[addIdx]`: rows [0, 16)
+// replace the RHS of `ops[j]` and rows [16, 32) the RHS of `ops[i]`.
+//
+// Returns a map from each paired RHS read op to the tile load that replaces it.
+static llvm::DenseMap<Operation *, amx::TileLoadOp>
+packInputs(OpBuilder &rewriter, Location loc,
+           ArrayRef<vector::ContractionOp> ops, Value matB, Type ipType,
+           unsigned int offset, Value bBuffer, bool pack, Value allocStore,
+           Value addIdx) {
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
+
+  // Index constants, created on the first matching pair and shared after.
+  Value c0;
+  Value c16;
+
+  for (size_t j = 0, e = ops.size(); j < e; ++j) {
+    for (size_t i = 0; i < e; ++i) {
+      // Copy the op handles out of the ArrayRef (it hands out const refs).
+      vector::ContractionOp first = ops[j];
+      vector::ContractionOp second = ops[i];
+      if (i == j || !validatePairVectorContract(first, second, true, 16))
+        continue;
+
+      // Skip RHS reads already covered by an earlier pair.
+      Operation *readOpRhs = first.getRhs().getDefiningOp();
+      if (readsToTileLoads.contains(readOpRhs))
+        continue;
+
+      if (!c0) {
+        c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+      }
+
+      if (pack)
+        performShuffle(rewriter, loc, matB, ipType, offset, bBuffer,
+                       allocStore);
+
+      auto loadLo = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+                                            ValueRange{addIdx, c0, c0});
+      auto loadHi = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+                                            ValueRange{addIdx, c16, c0});
+      readsToTileLoads.try_emplace(readOpRhs, loadLo);
+      readsToTileLoads.try_emplace(second.getRhs().getDefiningOp(), loadHi);
+    }
+  }
+
+  return readsToTileLoads;
 }
 
 // Creates tiled amx dot-products.
@@ -210,16 +346,26 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
                                         SmallVector<vector::ContractionOp> ops,
                                         Value matA, Value matB, Type ipType,
                                         Type opType, ValueRange accIterArgs,
-                                        unsigned int offset) {
+                                        unsigned int offset, bool isVnni,
+                                        Value bBuffer, bool pack,
+                                        Value allocStore, Value addIdx) {
 
-  auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
-  auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+  if (isVnni) {
+    matA = collapseInnerDims(rewriter, loc, matA);
+    matB = collapseInnerDims(rewriter, loc, matB);
+  }
 
   SmallVector<Value> accumulators;
   // Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
   // or load operations.
   llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
 
+  // function call to make the flat.
+  if (!isVnni) {
+    readsToTileLoads = packInputs(rewriter, loc, ops, matB, ipType, offset,
+                                  bBuffer, pack, allocStore, addIdx);
+  }
+
   // Iterate over the contraction operations and compute the tiled dot-product.
   for (size_t i = 0; i < ops.size(); i++) {
 
@@ -229,8 +375,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
     if (itLhs != readsToTileLoads.end()) {
       tilesLhs = itLhs->second;
     } else {
-      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
-                                 subviewCollapseLhs, ipType, false, offset);
+      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
+                                 false, offset, isVnni);
       readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
     }
 
@@ -240,8 +386,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
     if (itRhs != readsToTileLoads.end()) {
       tilesRhs = itRhs->second;
     } else {
-      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
-                                 subviewCollapseRhs, ipType, true, offset);
+      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
+                                 true, offset, isVnni);
       readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
     }
 
@@ -276,6 +422,171 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
+// Returns a 0/1 buffer-slot index for the current iteration: by default
+// `(iv_K / (16 * blockingFactor)) % 2`, or `iv_red % 2` when neither `nDimK`
+// nor `pack` is set. If `oddDimK` is set, the parity `iv_red % 2` is
+// additionally added in, modulo 2.
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc, Value iv_K,
+                               Value iv_red, bool oddDimK, bool nDimK,
+                               bool pack, unsigned int blockingFactor) {
+  Type indexTy = rewriter.getIndexType();
+  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+  // Only materialize the `iv_K` quotient on the path that actually uses it.
+  Value slot;
+  if (!nDimK && !pack) {
+    slot = arith::RemUIOp::create(rewriter, loc, indexTy, iv_red, c2);
+  } else {
+    Value kBlock =
+        arith::ConstantIndexOp::create(rewriter, loc, 16 * blockingFactor);
+    Value quotient = arith::DivUIOp::create(rewriter, loc, iv_K, kBlock);
+    slot = arith::RemUIOp::create(rewriter, loc, indexTy, quotient, c2);
+  }
+
+  if (oddDimK) {
+    Value redParity =
+        arith::RemUIOp::create(rewriter, loc, indexTy, iv_red, c2);
+    Value sum = arith::AddIOp::create(rewriter, loc, indexTy, slot, redParity);
+    slot = arith::RemUIOp::create(rewriter, loc, indexTy, sum, c2);
+  }
+  return slot;
+}
+
+// Builds the replacement for the inner reduction loop: an scf.for over
+// [lowerBound, upperBound) with `step` that carries `loopItrArgs`. Its body
+// clones `vectorOpLhs`/`vectorOpRhs`, remapping their operands at
+// `getIndexPosition(...) + 1` to `ivOuterLoop` and the new induction variable,
+// emits the AMX dot-products via createTiledDp and yields the accumulators.
+//
+// For non-VNNI inputs the RHS is staged through `bBuffer` slots 0/1:
+// `allocStore` is the slot packed in this iteration and `allocGet` the slot
+// read by the tile loads. The RHS may be cloned at indices ahead of the
+// current iteration (next outer-loop iteration, or `16 * blockingFactor`
+// further along the inner loop) -- NOTE(review): presumably so the next block
+// is packed while the current one is consumed; confirm. When `pack` is false
+// no RHS value is forwarded and only `allocGet` is computed.
+static scf::ForOp
+createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
+            Value upperBound, Value step, SmallVector<Value> loopItrArgs,
+            Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
+            Operation *vectorOpLhs, Operation *vectorOpRhs,
+            vector::ContractionOp contractOp, scf::ForOp outerLoop,
+            scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
+            Value ivOuterLoop, Value bBuffer, bool pack,
+            arith::ConstantIndexOp innerLoopIndex, bool nDimK, bool oddDimK) {
+  return scf::ForOp::create(
+      rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
+      [&](OpBuilder &b, Location bodyLoc, Value ivNewInnerLoop,
+          ValueRange iterArgsNewInnerLoop) {
+        Type indexTy = b.getIndexType();
+        // Mutable copies: the RHS may be cloned at advanced indices, while
+        // the LHS always uses the current ones.
+        Value outerIv = ivOuterLoop;
+        Value innerIv = ivNewInnerLoop;
+
+        IRMapping lhsMapping;
+        if (outerLoop)
+          lhsMapping.map(
+              vectorOpLhs->getOperand(
+                  getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+              outerIv);
+        lhsMapping.map(
+            vectorOpLhs->getOperand(
+                getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+            innerIv);
+        Operation *lhsClone = b.clone(*vectorOpLhs, lhsMapping);
+
+        Value c0 = arith::ConstantIndexOp::create(b, bodyLoc, 0);
+        Value allocStore = c0;
+        Value allocGet = c0;
+
+        // Slot arithmetic is only needed on the non-VNNI path.
+        Value c1;
+        Value c2;
+        if (!isVnni) {
+          c1 = arith::ConstantIndexOp::create(b, bodyLoc, 1);
+          c2 = arith::ConstantIndexOp::create(b, bodyLoc, 2);
+        }
+        // The tile loads read the slot that is not being packed.
+        auto otherSlot = [&](Value slot) -> Value {
+          Value next = arith::AddIOp::create(b, bodyLoc, slot, c1);
+          return arith::RemUIOp::create(b, bodyLoc, indexTy, next, c2);
+        };
+        auto kBlockConst = [&]() -> Value {
+          return arith::ConstantIndexOp::create(b, bodyLoc,
+                                                16 * blockingFactor);
+        };
+
+        if (!isVnni) {
+          if (outerLoop) {
+            if (innerLoopIndex.value() == 0) {
+              if (pack) {
+                // Clone the RHS for the next outer-loop iteration.
+                innerIv = c0;
+                outerIv = arith::AddIOp::create(b, bodyLoc, c1, outerIv);
+                if (!nDimK || oddDimK)
+                  allocStore = arith::RemUIOp::create(b, bodyLoc, indexTy,
+                                                      outerIv, c2);
+                allocGet = otherSlot(allocStore);
+              }
+            } else {
+              // Clone the RHS `16 * blockingFactor` ahead on the inner loop.
+              Value kBlock = kBlockConst();
+              innerIv = arith::AddIOp::create(b, bodyLoc, kBlock, innerIv);
+              allocStore = bufferIndxToStore(b, bodyLoc, innerIv, outerIv,
+                                             oddDimK, nDimK, pack,
+                                             blockingFactor);
+              allocGet = otherSlot(allocStore);
+            }
+          } else if (pack) {
+            // Clone the RHS `16 * blockingFactor` ahead on the inner loop.
+            Value kBlock = kBlockConst();
+            innerIv = arith::AddIOp::create(b, bodyLoc, kBlock, innerIv);
+            Value quotient =
+                arith::DivUIOp::create(b, bodyLoc, innerIv, kBlock);
+            allocStore =
+                arith::RemUIOp::create(b, bodyLoc, indexTy, quotient, c2);
+            allocGet = otherSlot(allocStore);
+          }
+        }
+
+        IRMapping rhsMapping;
+        if (outerLoop)
+          rhsMapping.map(
+              vectorOpRhs->getOperand(
+                  getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+              outerIv);
+        rhsMapping.map(
+            vectorOpRhs->getOperand(
+                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+            innerIv);
+        Operation *rhsClone = b.clone(*vectorOpRhs, rhsMapping);
+        Value matB = rhsClone->getResult(0);
+
+        // Without packing no RHS value is forwarded to createTiledDp; only
+        // the slot to read from is computed.
+        if (!isVnni && !pack) {
+          Value kBlock = kBlockConst();
+          matB = Value();
+          if (outerLoop) {
+            allocGet = bufferIndxToStore(b, bodyLoc, kBlock, outerIv, oddDimK,
+                                         nDimK, pack, blockingFactor);
+          } else {
+            Value quotient =
+                arith::DivUIOp::create(b, bodyLoc, innerIv, kBlock);
+            allocGet =
+                arith::RemUIOp::create(b, bodyLoc, indexTy, quotient, c2);
+          }
+        }
+
+        SmallVector<Value> accumulators = createTiledDp(
+            b, bodyLoc, ops, lhsClone->getResult(0), matB, ipType, opType,
+            iterArgsNewInnerLoop, blockingFactor, isVnni, bBuffer, pack,
+            allocStore, allocGet);
+        scf::YieldOp::create(b, bodyLoc, accumulators);
+      });
+}
+
 // Implements tiled dot-product operation for a vector.contract operation or a
 // sequence of vector.contracts inside the reduction loops.
 //
@@ -326,9 +638,6 @@ struct VectorContractToAMXDotProduct
       return rewriter.notifyMatchFailure(contractOp,
                                          "Only F32 for BF16 or Int32 for Int8 "
                                          "accumulation type is supported.");
-    if (!isVnni)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Only VNNI-packed inputs are supported.");
 
     Operation *accReadOp =
         traceToVectorReadLikeParentOperation(contractOp.getAcc());
@@ -342,8 +651,13 @@ struct VectorContractToAMXDotProduct
                       "transfer_read or a load. And, the result should be "
                       "stored using transfer_write or store.");
 
-    Type ipType = rewriter.getBF16Type();
-    Type opType = rewriter.getF32Type();
+    Type ipType;
+    Type opType;
+
+    if (lhsTy.getElementType().isBF16()) {
+      ipType = rewriter.getBF16Type();
+      opType = rewriter.getF32Type();
+    }
 
     if (lhsTy.getElementType().isSignlessInteger(8)) {
       ipType = rewriter.getIntegerType(8);
@@ -360,13 +674,21 @@ struct VectorContractToAMXDotProduct
       return rewriter.notifyMatchFailure(
           contractOp, "The accumulator read is in different block.");
 
+    unsigned int dimValue = blockingFactor;
+    if (!isVnni)
+      dimValue = 16 * blockingFactor;
+
     // Case 1: For just one VC rewrite. Where all accumulator read/write
     // within the same block.
     if (accReadOp->getBlock() == contractOp->getBlock() &&
         resultWriteOp->getBlock() == contractOp->getBlock()) {
 
+      bool collapse = false;
+      if (isVnni)
+        collapse = true;
+
       LogicalResult validate = validateContractOps(
-          rewriter, contractOp, blockingFactor, Value(), Value(), false);
+          rewriter, contractOp, dimValue, Value(), Value(), false);
 
       if (failed(validate))
         return rewriter.notifyMatchFailure(
@@ -377,18 +699,20 @@ struct VectorContractToAMXDotProduct
       Location loc = contractOp.getLoc();
 
       auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
-                                        contractOp.getLhs(), true);
+                                        contractOp.getLhs(), collapse);
       if (failed(srcIndxLhs))
         return rewriter.notifyMatchFailure(contractOp,
                                            "The LHS src is not a MemRef type.");
       auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
 
       auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
-                                        contractOp.getRhs(), true);
+                                        contractOp.getRhs(), collapse);
       if (failed(srcIndxRhs))
         return rewriter.notifyMatchFailure(contractOp,
                                            "The RHS src is not a MemRef type.");
-      auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
+      auto rhsSrc = *srcIndxRhs;
+      auto srcBuffRhs = rhsSrc.first;
+      auto indicesRhs = rhsSrc.second;
 
       auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
                                         contractOp.getAcc(), false);
@@ -401,8 +725,112 @@ struct VectorContractToAMXDotProduct
       auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
       auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
                                              srcBuffLhs, indicesLhs);
-      auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
-                                             srcBuffRhs, indicesRhs);
+
+      // Create the subview and then load.
+      //
+      amx::TileLoadOp loadRhs;
+      if (!isVnni) {
+        VectorType vecTy;
+        SmallVector<OpFoldResult> indexVals;
+        llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
+            .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+              indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                                    readOp.getIndices().end());
+              vecTy = readOp.getType();
+            });
+        auto one = rewriter.getIndexAttr(1);
+        SmallVector<OpFoldResult> strides(indexVals.size(), one);
+        SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
+            contractOp.getRhs().getDefiningOp()->getContext(),
+            vecTy.getShape());
+        auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
+                                                 indexVals, sizes, strides);
+        auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
+        auto bBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+        // create a loop that swaps them.
+
+        Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+        Value step =
+            arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
+        Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
+                                                      (blockingFactor * 16));
+        Value nextLoadIndx =
+            arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
+        Value nextStoreIndx = arith::ConstantIndexOp::create(
+            rewriter, loc, 16 * (blockingFactor / 2));
+
+        scf::ForOp::create(
+            rewriter, loc, c0, uBound, step, ValueRange{},
+            [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+                ValueRange iterArgs) {
+              Value i1_load =
+                  arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
+
+              indicesRhs[indicesRhs.size() - 2] = iv;
+              ValueRange range1(indicesRhs);
+              auto vec1 = vector::LoadOp::create(
+                  rewriter, loc,
+                  VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+                  range1);
+
+              indicesRhs[indicesRhs.size() - 2] = i1_load;
+              ValueRange range2(indicesRhs);
+              auto vec2 = vector::LoadOp::create(
+                  rewriter, loc,
+                  VectorType::get(16 * (blockingFactor / 2), ipType), subview,
+                  range2);
+
+              vector::ShuffleOp shuffle1;
+              vector::ShuffleOp shuffle2;
+
+              if (blockingFactor == 2) {
+
+                shuffle1 = vector::ShuffleOp::create(
+                    rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+                    ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
+                                      6, 22, 7, 23});
+
+                shuffle2 = vector::ShuffleOp::create(
+                    rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
+                    ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
+                                      29, 14, 30, 15, 31});
+              }
+
+              if (blockingFactor == 4) {
+                shuffle1 = vector::ShuffleOp::create(
+                    rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+                    ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
+                                      2, 18, 34, 50, 3, 19, 35, 51,
+                                      4, 20, 36, 52, 5, 21, 37, 53,
+                                      6, 22, 38, 54, 7, 23, 39, 55});
+
+                shuffle2 = vector::ShuffleOp::create(
+                    rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
+                    ArrayRef<int64_t>{8,  24, 40, 56, 9,  25, 41, 57,
+                                      10, 26, 42, 58, 11, 27, 43, 59,
+                                      12, 28, 44, 60, 13, 29, 45, 61,
+                                      14, 30, 46, 62, 15, 31, 47, 63});
+              }
+
+              auto rem = arith::RemUIOp::create(
+                  rewriter, loc, rewriter.getIndexType(), iv, step);
+
+              vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+                                      ValueRange{rem, c0});
+              vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+                                      ValueRange{rem, nextStoreIndx});
+
+              scf::YieldOp::create(nestedBuilder, loc);
+            });
+        loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+                                          ValueRange{c0, c0});
+      } else {
+
+        loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
+                                          indicesRhs);
+      }
 
       auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
       auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
@@ -429,7 +857,6 @@ struct VectorContractToAMXDotProduct
     // reduction loop.
     SmallVector<scf::ForOp> loopLists;
     Operation *current = contractOp;
-
     while (true) {
       Operation *parent = current->getParentOfType<scf::ForOp>();
       loopLists.push_back(dyn_cast<scf::ForOp>(parent));
@@ -440,7 +867,6 @@ struct VectorContractToAMXDotProduct
 
       current = parent;
     }
-
     if (loopLists.size() > 2 || loopLists.size() == 0)
       return rewriter.notifyMatchFailure(
           contractOp, "Rewrite is supported until reduction loop depth of 2.");
@@ -458,7 +884,6 @@ struct VectorContractToAMXDotProduct
       return rewriter.notifyMatchFailure(contractOp,
                                          "The RHS src is not a MemRef type.");
     auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
-
     Operation *vectorOpLhs;
     llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
         .Case<TransferReadOp, LoadOp>([&](auto readOp) {
@@ -478,7 +903,7 @@ struct VectorContractToAMXDotProduct
       if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
 
         LogicalResult validate = validateContractOps(
-            rewriter, contract, blockingFactor, srcBuffLhs, srcBuffRhs, true);
+            rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
 
         if (failed(validate))
           return rewriter.notifyMatchFailure(
@@ -490,8 +915,8 @@ struct VectorContractToAMXDotProduct
       }
     }
 
-    scf::ForOp outerLoop;
     scf::ForOp innerLoop;
+    scf::ForOp outerLoop;
 
     scf::ForOp newLoop;
     // Case 2a: Reduction loop depth is 2.
@@ -502,126 +927,248 @@ struct VectorContractToAMXDotProduct
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
 
-      newLoop = scf::ForOp::create(
-          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
-          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
-          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
-              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
-            auto newInnerLoop = scf::ForOp::create(
-                rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
-                innerLoop.getUpperBound(), innerLoop.getStep(),
-                iterArgsOuterLoop,
-                [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
-                    Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
-                  IRMapping mapping;
-                  mapping.map(
-                      vectorOpLhs->getOperand(
-                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
-                      ivOuterLoop);
-                  mapping.map(
-                      vectorOpLhs->getOperand(
-                          getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
-                      ivNewInnerLoop);
-                  auto lhsClone =
-                      rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
-
-                  IRMapping rhsMapping;
-                  rhsMapping.map(
-                      vectorOpRhs->getOperand(
-                          getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
-                      ivOuterLoop);
-                  rhsMapping.map(
-                      vectorOpRhs->getOperand(
-                          getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
-                      ivNewInnerLoop);
-                  auto rhsClone =
-                      rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
-
-                  SmallVector<Value> accumulators = createTiledDp(
-                      rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
-                      rhsClone->getResult(0), ipType, opType,
-                      iterArgsNewInnerLoop, blockingFactor);
-
-                  scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
-                                       accumulators);
-                });
-
-            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
-                                 newInnerLoop.getResults());
-          });
+      if (isVnni) {
+        newLoop = scf::ForOp::create(
+            rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+            outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
+            [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+                Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+              auto newInnerLoop = createLoops(
+                  rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                  innerLoop.getUpperBound(), innerLoop.getStep(),
+                  iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
+                  vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
+                  ops, ivOuterLoop, nullptr, true, nullptr, false, false);
+
+              scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+                                   newInnerLoop.getResults());
+            });
+
+      } else {
+
+        bool nDimK = false;
+        bool oddDimK = false;
+
+        int64_t ubVal = 16 * blockingFactor;
+        mlir::Value ub = innerLoop.getUpperBound();
+        if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+          if (auto intAttr =
+                  llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+            ubVal = intAttr.getInt();
+          }
+        }
+
+        nDimK = ubVal > 16 * blockingFactor;
+        oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+
+        rewriter.setInsertionPoint(outerLoop);
+
+        auto c0 =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+        auto c1 =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+
+        auto spillLoopBound = arith::ConstantIndexOp::create(
+            rewriter, outerLoop.getLoc(), 16 * blockingFactor);
+        Value subBRLoop = arith::SubIOp::create(rewriter, outerLoop.getLoc(),
+                                                outerLoop.getUpperBound(), c1);
+        Value subKloop =
+            arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+                                  innerLoop.getUpperBound(), spillLoopBound);
+        auto bufferType =
+            MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+        auto bBuffer =
+            memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+        // First Shuffling outside the reduction loops
+        IRMapping rhsMapping;
+        rhsMapping.map(
+            vectorOpRhs->getOperand(
+                getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+            c0);
+        rhsMapping.map(
+            vectorOpRhs->getOperand(
+                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+            c0);
+        auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+        performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
+                       ipType, blockingFactor, bBuffer, c0);
+
+        // First Set of Loops
+        auto newLoop1 = scf::ForOp::create(
+            rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(), subBRLoop,
+            outerLoop.getStep(), loopItrArgs,
+            [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+                Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+              auto newInnerLoop1 = createLoops(
+                  rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                  subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+                  opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+                  contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+                  true, spillLoopBound, nDimK, oddDimK);
+
+              auto newInnerLoop =
+                  createLoops(rewriter, innerLoop.getLoc(), subKloop,
+                              innerLoop.getUpperBound(), innerLoop.getStep(),
+                              newInnerLoop1.getResults(), ipType, opType,
+                              blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+                              contractOp, outerLoop, innerLoop, ops,
+                              ivOuterLoop, bBuffer, true, c0, nDimK, oddDimK);
+
+              scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+                                   newInnerLoop.getResults());
+            });
+
+        // Last set of Loops
+        newLoop = scf::ForOp::create(
+            rewriter, outerLoop.getLoc(), subBRLoop, outerLoop.getUpperBound(),
+            outerLoop.getStep(), newLoop1.getResults(),
+            [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
+                Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
+              auto newInnerLoop1 = createLoops(
+                  rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+                  subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
+                  opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+                  contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
+                  true, spillLoopBound, nDimK, oddDimK);
+
+              auto newInnerLoop =
+                  createLoops(rewriter, innerLoop.getLoc(), subKloop,
+                              innerLoop.getUpperBound(), innerLoop.getStep(),
+                              newInnerLoop1.getResults(), ipType, opType,
+                              blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+                              contractOp, outerLoop, innerLoop, ops,
+                              ivOuterLoop, bBuffer, false, c0, nDimK, oddDimK);
+
+              scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
+                                   newInnerLoop.getResults());
+            });
+      }
     }
 
-    // Case 2b: Reduction loop depth is 1.
     if (loopLists.size() == 1) {
       outerLoop = loopLists[0];
+      innerLoop = loopLists[0];
 
       SmallVector<Value> loopItrArgs = createTileZeros(
-          rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
-      newLoop = scf::ForOp::create(
-          rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
-          outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
-          [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
-              Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
-            IRMapping mapping;
-            mapping.map(
-                vectorOpLhs->getOperand(
-                    getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
-                ivOuterLoop);
-
-            auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
-
-            IRMapping rhsMapping;
-            rhsMapping.map(
-                vectorOpRhs->getOperand(
-                    getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
-                ivOuterLoop);
-
-            auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
-
-            SmallVector<Value> accumulators = createTiledDp(
-                rewriter, locOuterLoop, ops, lhsClone->getResult(0),
-                rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
-                blockingFactor);
-
-            scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
-          });
+          rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
+
+      if (isVnni) {
+
+        newLoop = createLoops(
+            rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+            innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
+            opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+            contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
+            nullptr, false, false);
+
+      } else {
+        bool nDimK = false;
+        bool oddDimK = false;
+
+        int64_t ubVal = 16 * blockingFactor;
+        mlir::Value ub = innerLoop.getUpperBound();
+        if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
+          if (auto intAttr =
+                  llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
+            ubVal = intAttr.getInt();
+          }
+        }
+
+        nDimK = ubVal > 16 * blockingFactor;
+        oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+        rewriter.setInsertionPoint(innerLoop);
+        auto c0 =
+            arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
+        auto spillLoopBound = arith::ConstantIndexOp::create(
+            rewriter, innerLoop.getLoc(), 16 * blockingFactor);
+
+        Value subKloop =
+            arith::SubIOp::create(rewriter, innerLoop.getLoc(),
+                                  innerLoop.getUpperBound(), spillLoopBound);
+
+        auto bufferType =
+            MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
+        auto bBuffer =
+            memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
+
+        // First Shuffling outside the reduction loops
+        IRMapping rhsMapping;
+        rhsMapping.map(
+            vectorOpRhs->getOperand(
+                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+            c0);
+        auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+        performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
+                       ipType, blockingFactor, bBuffer, c0);
+
+        auto newLoop1 = createLoops(
+            rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(), subKloop,
+            innerLoop.getStep(), loopItrArgs, ipType, opType, blockingFactor,
+            isVnni, vectorOpLhs, vectorOpRhs, contractOp, nullptr, innerLoop,
+            ops, nullptr, bBuffer, true, spillLoopBound, nDimK, oddDimK);
+
+        newLoop = createLoops(rewriter, innerLoop.getLoc(), subKloop,
+                              innerLoop.getUpperBound(), innerLoop.getStep(),
+                              newLoop1.getResults(), ipType, opType,
+                              blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
+                              contractOp, nullptr, innerLoop, ops, nullptr,
+                              bBuffer, false, c0, nDimK, oddDimK);
+      }
     }
 
-    // post processing after the loop creation.
     // Copy the amx tile accumulation results to a MemRef buffer, add the
     // initial accumulation value, and store back to the C-Matrix
-    auto bufferType = MemRefType::get({16, 16}, opType);
-    auto bBuffer =
-        memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
 
-    SmallVector<Value> dps = newLoop.getResults();
-    for (size_t i = 0; i < ops.size(); i++) {
-      vector::ContractionOp contOp = ops[i];
-      Operation *resultWriteOp =
-          traceToVectorWriteLikeUserOperation(contOp.getResult());
-      rewriter.setInsertionPoint(resultWriteOp);
+    if (!isVnni) {
+      Location loc = outerLoop.getLoc();
+      SmallVector<Value> dps = newLoop.getResults();
+      auto bufferType = MemRefType::get({32, 32}, opType);
+      auto bBuffer =
+          memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+      for (int i = 0, k = 0; i < 32; i = i + 16) {
+        for (int j = 0; j < 32; j = j + 16) {
+          Value indexOp_i =
+              arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), i);
+          Value indexOp_j =
+              arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), j);
+          amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+                                   ValueRange{indexOp_i, indexOp_j}, dps[k]);
+          k++;
+        }
+      }
+      auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+      auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+      auto mBound = arith::ConstantIndexOp::create(rewriter, loc, 32);
 
-      Value indexOp_0 =
-          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+      scf::ForOp::create(
+          rewriter, loc, c0, mBound, one, ValueRange{},
+          [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+              ValueRange iterArgs) {
+            auto row = vector::LoadOp::create(rewriter, loc,
+                                              VectorType::get(16, opType),
+                                              bBuffer, ValueRange{iv, c0});
 
-      amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
-                               ValueRange{indexOp_0, indexOp_0}, dps[i]);
+            auto row2 = vector::LoadOp::create(rewriter, loc,
+                                               VectorType::get(16, opType),
+                                               bBuffer, ValueRange{iv, c16});
 
-      auto c0 = arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-      auto one =
-          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
-      auto mBound =
-          arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+            auto shuffle1 = vector::ShuffleOp::create(
+                rewriter, loc, VectorType::get(16, opType), row, row2,
+                ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
+                                  21, 22, 23});
 
-      scf::ForOp::create(
-          rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
-          [&](OpBuilder &builder, Location loc, Value iv, ValueRange iterArgs) {
-            auto resultAcc = vector::LoadOp::create(
-                rewriter, loc, VectorType::get(16, opType), bBuffer,
-                ValueRange{iv, c0});
+            auto shuffle2 = vector::ShuffleOp::create(
+                rewriter, loc, VectorType::get(16, opType), row, row2,
+                ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
+                                  28, 29, 30, 31});
 
             Operation *accReadOp =
-                traceToVectorReadLikeParentOperation(ops[i].getAcc());
+                traceToVectorReadLikeParentOperation(contractOp.getAcc());
 
             Value srcBuffAcc;
             SmallVector<Value> indicesAcc;
@@ -641,24 +1188,119 @@ struct VectorContractToAMXDotProduct
                       });
                 });
 
-            Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
-            indicesAcc[indicesAcc.size() - 2] = sum;
-
-            auto acc = vector::LoadOp::create(rewriter, loc,
-                                              VectorType::get(16, opType),
-                                              srcBuffAcc, indicesAcc);
-            Value addition;
-            if (ipType.isBF16())
-              addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
-
-            if (ipType.isSignlessInteger(8))
-              addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
-
-            vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+            indicesAcc[indicesAcc.size() - 2] = iv;
+            indicesAcc[indicesAcc.size() - 1] = c0;
+
+            Value valueCRow1 = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+                indicesAcc);
+            indicesAcc[indicesAcc.size() - 1] = c16;
+
+            Value valueCRow2 = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
+                indicesAcc);
+            Value addOp;
+            Value addOp2;
+
+            if (ipType.isBF16()) {
+              addOp =
+                  arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+              addOp2 =
+                  arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
+            }
+
+            if (ipType.isSignlessInteger(8)) {
+              addOp =
+                  arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
+
+              addOp2 =
+                  arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
+            }
+            indicesAcc[indicesAcc.size() - 1] = c0;
+            vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
+                                    indicesAcc);
+            indicesAcc[indicesAcc.size() - 1] = c16;
+            vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
                                     indicesAcc);
 
-            scf::YieldOp::create(builder, outerLoop.getLoc());
+            scf::YieldOp::create(nestedBuilder, loc);
           });
+    }
+    auto bufferType = MemRefType::get({16, 16}, opType);
+    auto bBuffer =
+        memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
+
+    SmallVector<Value> dps = newLoop.getResults();
+    for (size_t i = 0; i < ops.size(); i++) {
+      vector::ContractionOp contOp = ops[i];
+      Operation *resultWriteOp =
+          traceToVectorWriteLikeUserOperation(contOp.getResult());
+      if (isVnni) {
+        rewriter.setInsertionPoint(resultWriteOp);
+
+        Value indexOp_0 =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+
+        amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+                                 ValueRange{indexOp_0, indexOp_0}, dps[i]);
+
+        auto c0 =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
+        auto one =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
+        auto mBound =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
+
+        scf::ForOp::create(
+            rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
+            [&](OpBuilder &builder, Location loc, Value iv,
+                ValueRange iterArgs) {
+              auto resultAcc = vector::LoadOp::create(
+                  rewriter, loc, VectorType::get(16, opType), bBuffer,
+                  ValueRange{iv, c0});
+
+              Operation *accReadOp =
+                  traceToVectorReadLikeParentOperation(ops[i].getAcc());
+
+              Value srcBuffAcc;
+              SmallVector<Value> indicesAcc;
+
+              llvm::TypeSwitch<Operation *>(accReadOp)
+                  .Case<TransferReadOp, LoadOp>([&](auto readOp) {
+                    srcBuffAcc = readOp.getOperand(0);
+
+                    auto indices = readOp.getIndices();
+                    indicesAcc.reserve(indices.size());
+
+                    llvm::transform(
+                        indices, std::back_inserter(indicesAcc),
+                        [&](OpFoldResult ofr) {
+                          return mlir::getValueOrCreateConstantIndexOp(
+                              rewriter, loc, ofr);
+                        });
+                  });
+
+              Value sum =
+                  arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
+              indicesAcc[0] = sum;
+
+              auto acc = vector::LoadOp::create(rewriter, loc,
+                                                VectorType::get(16, opType),
+                                                srcBuffAcc, indicesAcc);
+              Value addition;
+              if (ipType.isBF16())
+                addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
+
+              if (ipType.isSignlessInteger(8))
+                addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
+
+              vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
+                                      indicesAcc);
+
+              scf::YieldOp::create(builder, outerLoop.getLoc());
+            });
+      }
 
       rewriter.eraseOp(resultWriteOp);
     }
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index cde15b680a037..151946453df81 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -216,6 +216,122 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8>
+!memrefB = memref<64x32xi8>
+!memrefC = memref<32x32xi32>
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_int8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  // A is read as M x K and B as K x N; B carries no VNNI packing dimension.
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+  // d1 (M) and d2 (N) are parallel; only d3 (K) is reduced.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+// Expect B to be packed online via vector.shuffle before the tile_muli.
+// CHECK-LABEL: @online_packing_int8
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55] : vector<32xi8>, vector<32xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_muli
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-NOT: vector.contract
+
+// Lower with only the x86 vector_contract_to_amx_dot_product patterns.
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @online_packing_bf16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  // Per batch, A is read as M x K and B as K x N; B has no VNNI dimension.
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+  // d0 (batch) and d3 (K) are both reduced: a batch-reduce matmul into C.
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+// Expect B to be packed online via vector.shuffle before the tile_mulf.
+// CHECK-LABEL: @online_packing_bf16
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23] : vector<16xbf16>, vector<16xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] : vector<16xbf16>, vector<16xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x32xbf16>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecAB = vector<1x16x16x2xbf16>
 !vecC = vector<16x16xf32>
 !memrefA = memref<1x32x16x2xbf16, strided<[8192, 128, 2, 1], offset: ?>>
@@ -483,6 +599,199 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x32xbf16, strided<[6144, 96, 1], offset: ?>>
+!memrefB = memref<1x32x32xbf16, strided<[12288, 128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @online_packing_bf16_loop(%arg0: memref<16x64x96xbf16>, %arg1: memref<16x96x128xbf16>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c96 = arith.constant 96 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %7:4 = scf.for %arg10 = %c0 to %c96 step %c32 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 32, 32] [1, 1, 1] :
+                memref<16x64x96xbf16> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 32, 32] [1, 1, 1] :
+                memref<16x96x128xbf16> to !memrefB
+          %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecA
+          %9 = vector.transfer_read %subview_0[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecA
+          %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecB
+          %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecB
+
+          %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+          %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+          %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+          %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+          scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+        }
+        scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xf32>
+  memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+  return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_bf16_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-COUNT-4: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %6:4 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1] : memref<64x256xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1] : memref<256x128xi8> to !memrefB
+        %7 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+        %8 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+        %9 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+        %10 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+        %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %9, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %10, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %9, %arg8 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xi32>
+  memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+  return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_muli
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x16x16x4xi8>
 !vecB = vector<1x16x16x4xi8>
 !vecC = vector<16x16xi32>
@@ -637,32 +946,32 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-!vecA = vector<16x64xi8>
-!vecB = vector<64x16xi8>
-!vecC = vector<16x16xi32>
-!memrefA = memref<32x64xi8>
-!memrefB = memref<64x32xi8>
-!memrefC = memref<32x32xi32>
-#map = affine_map<(d1, d2, d3) -> (d1, d3)>
-#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
-#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
-func.func @negative_no_vnni_packed(
+!vecA = vector<1x16x32xbf16>
+!vecB = vector<1x32x32xbf16>
+!vecC = vector<16x32xf32>
+!memrefA = memref<1x32x32xbf16>
+!memrefB = memref<1x32x32xbf16>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_wrong_dimensions_online_packing(
   %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
 {
   %c0 = arith.constant 0 : index
-  %0 = ub.poison : i8
-  %32 = ub.poison : i32
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
 
-  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
         !memrefA, !vecA
-  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
         !memrefB, !vecB
 
   %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
 
   %4 = vector.contract {
     indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
     kind = #vector.kind<add>}
     %1, %2, %3 : !vecA, !vecB into !vecC
 
@@ -671,13 +980,13 @@ func.func @negative_no_vnni_packed(
   return %arg2 : !memrefC
 }
 
-// CHECK-LABEL: @negative_no_vnni_packed
-// CHECK-NOT: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xi8>
-// CHECK-NOT: x86.amx.tile_muli
-// CHECK-NOT: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK-LABEL: @negative_wrong_dimensions_online_packing
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK-NOT: x86.amx.tile_store
 // CHECK: vector.contract
 
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op

>From caabcfde4a773121a31a4442e8dcab28c469c47c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 26 Mar 2026 21:21:44 -0700
Subject: [PATCH 2/5] code refactoring + addition of two -ve test-cases

---
 .../VectorContractToAMXDotProduct.cpp         | 491 ++++++++++--------
 .../X86/AMX/vector-contract-to-tiled-dp.mlir  | 151 ++++++
 2 files changed, 439 insertions(+), 203 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 2b159e6f59cb9..744c065b4e05e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1,4 +1,4 @@
-//===- VectorContractToAMXDotProduct.cpp ----------------------===//
+//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -211,42 +211,41 @@ static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
 }
 
 static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
-                           Type ipType, unsigned int offset, Value bBuffer,
-                           Value allocStore) {
+                           Type ipType, unsigned int offset, Value packedBuffer,
+                           Value indxToStoreInBuffer) {
 
-  auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
   Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-  SmallVector<Value> vals(subview.getOffsets().size(), c0);
-
-  // Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
   Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
 
+  auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
+  SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+
   Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
   Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
-  Value nLoadIndx = arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+  Value offsetIndx =
+      arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
 
   scf::ForOp::create(
       rewriter, loc, c0, cBound, cStep, ValueRange{},
       [&](OpBuilder &nestedBuilder, Location loc, Value iv,
           ValueRange iterArgs) {
-        Value i1_load = arith::AddIOp::create(rewriter, loc, nLoadIndx, iv);
-
-        vals[vals.size() - 2] = iv;
-        ValueRange range1(vals);
+        subviewOffset[subviewOffset.size() - 2] = iv;
         auto vec1 = vector::LoadOp::create(
             rewriter, loc, VectorType::get((16 * offset), ipType), matB,
-            range1);
+            ValueRange(subviewOffset));
 
-        vals[vals.size() - 2] = i1_load;
-        ValueRange range2(vals);
+        // Advance the row index by offset/2 (1 for bf16, 2 for i8) to load
+        // the next 32 (bf16) or 64 (i8) elements.
+        Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
+        subviewOffset[subviewOffset.size() - 2] = incIV;
         auto vec2 = vector::LoadOp::create(
             rewriter, loc, VectorType::get((16 * offset), ipType), matB,
-            range2);
+            ValueRange(subviewOffset));
 
         vector::ShuffleOp shuffle1;
         vector::ShuffleOp shuffle2;
 
-        if (offset == 2) {
+        if (ipType.isBF16()) {
 
           shuffle1 = vector::ShuffleOp::create(
               rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -263,7 +262,7 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
                                 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
         }
 
-        if (offset == 4) {
+        if (ipType.isSignlessInteger(8)) {
 
           shuffle1 = vector::ShuffleOp::create(
               rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -285,13 +284,15 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
                   90, 122, 27, 59,  91, 123, 28, 60,  92, 124, 29, 61,  93, 125,
                   30, 62,  94, 126, 31, 63,  95, 127});
         }
-        Value j_pos = arith::DivUIOp::create(rewriter, loc, iv, cStep);
-        Value j16_pos = arith::AddIOp::create(rewriter, loc, c16, j_pos);
 
-        vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
-                                ValueRange{allocStore, j_pos, c0});
-        vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
-                                ValueRange{allocStore, j16_pos, c0});
+        // Row indices at which the two shuffled halves are stored.
+        Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
+        Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
+
+        vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
+                                ValueRange{indxToStoreInBuffer, ivShuff1, c0});
+        vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
+                                ValueRange{indxToStoreInBuffer, ivShuff2, c0});
 
         scf::YieldOp::create(nestedBuilder, loc);
       });
@@ -300,9 +301,12 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
 static llvm::DenseMap<Operation *, amx::TileLoadOp>
 packInputs(OpBuilder &rewriter, Location loc,
            SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
-           unsigned int offset, Value bBuffer, bool pack, Value allocStore,
-           Value addIdx) {
+           unsigned int offset, Value packedBuffer, bool pack,
+           Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
   llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
 
   for (size_t j = 0; j < ops.size(); j++) {
     for (size_t i = 0; i < ops.size(); i++) {
@@ -315,25 +319,23 @@ packInputs(OpBuilder &rewriter, Location loc,
           continue;
         }
 
-        Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-        Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
-
         if (pack) {
-          performShuffle(rewriter, loc, matB, ipType, offset, bBuffer,
-                         allocStore);
+          performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
+                         indxToStoreInBuffer);
         }
 
         amx::TileType tileType =
             amx::TileType::get({16, (16 * offset)}, ipType);
-        auto load = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
-                                            ValueRange{addIdx, c0, c0});
-
-        auto load1 = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
-                                             ValueRange{addIdx, c16, c0});
+        auto loadRow1 =
+            amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+                                    ValueRange{indxToLoadFromMatB, c0, c0});
 
-        readsToTileLoads.try_emplace(readOpRhs, load);
+        auto loadRow2 =
+            amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+                                    ValueRange{indxToLoadFromMatB, c16, c0});
 
-        readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), load1);
+        readsToTileLoads.try_emplace(readOpRhs, loadRow1);
+        readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
       }
     }
   }
@@ -342,13 +344,12 @@ packInputs(OpBuilder &rewriter, Location loc,
 }
 
 // Creates tiled amx dot-products.
-static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
-                                        SmallVector<vector::ContractionOp> ops,
-                                        Value matA, Value matB, Type ipType,
-                                        Type opType, ValueRange accIterArgs,
-                                        unsigned int offset, bool isVnni,
-                                        Value bBuffer, bool pack,
-                                        Value allocStore, Value addIdx) {
+static SmallVector<Value>
+createTiledDp(OpBuilder &rewriter, Location loc,
+              SmallVector<vector::ContractionOp> ops, Value matA, Value matB,
+              Type ipType, Type opType, ValueRange accIterArgs,
+              unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
+              Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
 
   if (isVnni) {
     matA = collapseInnerDims(rewriter, loc, matA);
@@ -360,10 +361,11 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
   // or load operations.
   llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
 
-  // function call to make the flat.
+  // Pack the non-VNNI B matrix online before creating the tile loads.
   if (!isVnni) {
-    readsToTileLoads = packInputs(rewriter, loc, ops, matB, ipType, offset,
-                                  bBuffer, pack, allocStore, addIdx);
+    readsToTileLoads =
+        packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
+                   indxToStoreInBuffer, indxToLoadFromMatB);
   }
 
   // Iterate over the contraction operations and compute the tiled dot-product.
@@ -422,36 +424,37 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
-static Value bufferIndxToStore(OpBuilder &rewriter, Location loc, Value iv_K,
-                               Value iv_red, bool oddDimK, bool nDimK,
-                               bool pack, unsigned int blockingFactor) {
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
+                               Value ivInnerLoop, Value ivOuterLoop,
+                               bool isInnerLoopUBHasOddQuot,
+                               bool isInnerLoopUBLarger, bool pack,
+                               unsigned int blockingFactor) {
 
   Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
-
-  Value nLoadIndx =
+  Value packOffset =
       arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
 
-  Value quotient_K = arith::DivUIOp::create(rewriter, loc, iv_K, nLoadIndx);
-  Value rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
-                                       quotient_K, c2);
+  Value quotientInnerLoop =
+      arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
+  Value remInnerLoop = arith::RemUIOp::create(
+      rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
 
-  if (!nDimK && !pack) {
-    rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
-                                   iv_red, c2);
+  if (!isInnerLoopUBLarger && !pack) {
+    remInnerLoop = arith::RemUIOp::create(
+        rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
   }
 
   // if K quotient is odd. Then, BR loop iv is taken
   // into consideration
-  if (oddDimK) {
-    auto rem_BR = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
-                                         iv_red, c2);
+  if (isInnerLoopUBHasOddQuot) {
+    auto remOuterLoop = arith::RemUIOp::create(
+        rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
     auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
-                                        rem_K, rem_BR);
-    rem_K = arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
-                                   remAdd, c2);
+                                        remInnerLoop, remOuterLoop);
+    remInnerLoop = arith::RemUIOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), remAdd, c2);
   }
-
-  return rem_K;
+  return remInnerLoop;
 }
 
 static scf::ForOp
@@ -461,10 +464,15 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
             Operation *vectorOpLhs, Operation *vectorOpRhs,
             vector::ContractionOp contractOp, scf::ForOp outerLoop,
             scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
-            Value ivOuterLoop, Value bBuffer, bool pack,
-            arith::ConstantIndexOp innerLoopIndex, bool nDimK, bool oddDimK) {
+            Value ivOuterLoop, Value packedBuffer, bool pack,
+            arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
+            bool isInnerLoopUBHasOddQuot) {
 
-  auto newLoop1 = scf::ForOp::create(
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+  auto newLoop = scf::ForOp::create(
       rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
       [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
           Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
@@ -479,11 +487,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                     ivNewInnerLoop);
         auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
 
-        Value c0 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 0);
-        Value c1 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 1);
-        Value c2 = arith::ConstantIndexOp::create(rewriter, locNewInnerLoop, 2);
-        Value allocStore = c0;
-        Value allocGet = c0;
+        Value indxToStoreInBuffer = c0;
+        Value indxToLoadFromBuffer = c0;
 
         if (!isVnni) {
           if (outerLoop) {
@@ -493,16 +498,17 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                     c1, ivOuterLoop);
 
-                if (!nDimK || oddDimK) {
-                  allocStore = arith::RemUIOp::create(rewriter, locNewInnerLoop,
-                                                      rewriter.getIndexType(),
-                                                      ivOuterLoop, c2);
+                if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
+                  indxToStoreInBuffer = arith::RemUIOp::create(
+                      rewriter, locNewInnerLoop, rewriter.getIndexType(),
+                      ivOuterLoop, c2);
                 }
 
-                Value addIdx =
-                    arith::AddIOp::create(rewriter, loc, allocStore, c1);
-                allocGet = arith::RemUIOp::create(
-                    rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+                Value indxToLoadFromMatB = arith::AddIOp::create(
+                    rewriter, loc, indxToStoreInBuffer, c1);
+                indxToLoadFromBuffer = arith::RemUIOp::create(
+                    rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
+                    c2);
               }
 
             } else {
@@ -510,13 +516,15 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                   rewriter, locNewInnerLoop, (16 * blockingFactor));
               ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                      nLoadIndx, ivNewInnerLoop);
-              allocStore =
+              indxToStoreInBuffer =
                   bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
-                                    oddDimK, nDimK, pack, blockingFactor);
-              Value addIdx =
-                  arith::AddIOp::create(rewriter, loc, allocStore, c1);
-              allocGet = arith::RemUIOp::create(
-                  rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+                                    isInnerLoopUBHasOddQuot,
+                                    isInnerLoopUBLarger, pack, blockingFactor);
+              Value indxToLoadFromMatB =
+                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+              indxToLoadFromBuffer =
+                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         indxToLoadFromMatB, c2);
             }
           } else {
             if (pack) {
@@ -526,13 +534,14 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                                                      nLoadIndx, ivNewInnerLoop);
               Value quotient_K = arith::DivUIOp::create(
                   rewriter, loc, ivNewInnerLoop, nLoadIndx);
-              allocStore = arith::RemUIOp::create(
+              indxToStoreInBuffer = arith::RemUIOp::create(
                   rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
 
-              Value addIdx =
-                  arith::AddIOp::create(rewriter, loc, allocStore, c1);
-              allocGet = arith::RemUIOp::create(
-                  rewriter, loc, rewriter.getIndexType(), addIdx, c2);
+              Value indxToLoadFromMatB =
+                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+              indxToLoadFromBuffer =
+                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         indxToLoadFromMatB, c2);
             }
           }
         }
@@ -558,10 +567,11 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
               Value nLoadIndx = arith::ConstantIndexOp::create(
                   rewriter, locNewInnerLoop, (16 * blockingFactor));
               matB = Value();
-              allocGet = c0;
-              allocGet =
+              indxToLoadFromBuffer = c0;
+              indxToLoadFromBuffer =
                   bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
-                                    oddDimK, nDimK, pack, blockingFactor);
+                                    isInnerLoopUBHasOddQuot,
+                                    isInnerLoopUBLarger, pack, blockingFactor);
             }
           } else {
             if (!pack) {
@@ -570,28 +580,30 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
               matB = Value();
               Value quotient_K = arith::DivUIOp::create(
                   rewriter, loc, ivNewInnerLoop, nLoadIndx);
-              allocGet = arith::RemUIOp::create(
+              indxToLoadFromBuffer = arith::RemUIOp::create(
                   rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
             }
           }
         }
 
+        // compute tiled dot-product
         SmallVector<Value> accumulators = createTiledDp(
             rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
             ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
-            bBuffer, pack, allocStore, allocGet);
+            packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
 
         scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
                              accumulators);
       });
 
-  return newLoop1;
+  return newLoop;
 }
 
 // Implements tiled dot-product operation for a vector.contract operation or a
 // sequence of vector.contracts inside the reduction loops.
 //
-// For example - for F32 type:
+// For example:
+// Case 1: register blocked vector.contract with prepacked input
 // ```
 //   vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
 //   vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
@@ -605,6 +617,52 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
 //   amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
 //   amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
 // ```
+//
+//
+// Case 2: register blocked vector.contract with online packing
+//
+// Output IR with online packing (software-pipelined):
+// Prologue: load, pack to VNNI, and store the B sub-matrix of the 0th
+// batch-reduce and K iteration.
+// scf.for (0 to 31) {
+//   - load the 0th and 1st vector<32xbf16>, pack into VNNI, and store the
+//     first shuffle at index 0 and the second shuffle at index 16 of the
+//     buffer.
+// }
+// scf.for (0 to br-2) { batch-reduce loop
+//   scf.for (0 to k-2) { K loop
+//     - load A matrix
+//     - s/w pipeline scf.for: load, pack to VNNI, and store the B sub-matrix
+//       for the next K iteration
+//     - load the VNNI-packed B of this K iteration; compute tiled dot-product
+//   }
+//   Last iteration of the K loop (k-1) {
+//     - load A matrix
+//     - s/w pipeline scf.for: load, pack to VNNI, and store the B sub-matrix
+//       for the next (batch-reduce, K) iteration
+//     - load the VNNI-packed B of this K iteration; compute tiled dot-product
+//   }
+// }
+// Last iteration of the batch-reduce loop (br-1) {
+//   scf.for (0 to k-2) { K loop
+//     - load A matrix
+//     - s/w pipeline scf.for: load, pack to VNNI, and store the B sub-matrix
+//       for the next K iteration
+//     - load the VNNI-packed B of this K iteration; compute tiled dot-product
+//   }
+//   Last iteration of the K loop (k-1) {
+//     - load A matrix
+//     - load the VNNI-packed B matrix of this K iteration from the buffer
+//     - compute the tiled dot-product
+//   }
+// }
+//
+// scf.for (0 to M)
+//   scf.for (0 to N)
+//     - Load the ith and i+1th acc
+//     - Shuffle them as we packed using vpunpack
+//     - Load C matrix and do arith.add with the shuffle
+//     - Store back into C matrix
 struct VectorContractToAMXDotProduct
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -727,7 +785,6 @@ struct VectorContractToAMXDotProduct
                                              srcBuffLhs, indicesLhs);
 
       // Create the subview and then load.
-      //
       amx::TileLoadOp loadRhs;
       if (!isVnni) {
         VectorType vecTy;
@@ -746,12 +803,10 @@ struct VectorContractToAMXDotProduct
         auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
                                                  indexVals, sizes, strides);
         auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
-        auto bBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
-
-        // create a loop that swaps them.
+        auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
 
+        // create a loop that does online packing.
         Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-
         Value step =
             arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
         Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
@@ -817,14 +872,14 @@ struct VectorContractToAMXDotProduct
               auto rem = arith::RemUIOp::create(
                   rewriter, loc, rewriter.getIndexType(), iv, step);
 
-              vector::StoreOp::create(rewriter, loc, shuffle1, bBuffer,
+              vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
                                       ValueRange{rem, c0});
-              vector::StoreOp::create(rewriter, loc, shuffle2, bBuffer,
+              vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
                                       ValueRange{rem, nextStoreIndx});
 
               scf::YieldOp::create(nestedBuilder, loc);
             });
-        loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, bBuffer,
+        loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
                                           ValueRange{c0, c0});
       } else {
 
@@ -915,6 +970,22 @@ struct VectorContractToAMXDotProduct
       }
     }
 
+    if (!isVnni) {
+      unsigned int pairCount = 0;
+      for (size_t j = 0; j < ops.size(); j++) {
+        for (size_t i = j; i < ops.size(); i++) {
+
+          if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+            pairCount = pairCount + 2;
+          }
+        }
+      }
+
+      if (pairCount != ops.size())
+        return rewriter.notifyMatchFailure(
+            contractOp, "Coudn't find the pair vector contract ");
+    }
+
     scf::ForOp innerLoop;
     scf::ForOp outerLoop;
 
@@ -946,8 +1017,8 @@ struct VectorContractToAMXDotProduct
 
       } else {
 
-        bool nDimK = false;
-        bool oddDimK = false;
+        bool isInnerLoopUBLarger = false;
+        bool isInnerLoopUBHasOddQuot = false;
 
         int64_t ubVal = 16 * blockingFactor;
         mlir::Value ub = innerLoop.getUpperBound();
@@ -958,27 +1029,27 @@ struct VectorContractToAMXDotProduct
           }
         }
 
-        nDimK = ubVal > 16 * blockingFactor;
-        oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+        isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
+        isInnerLoopUBHasOddQuot =
+            (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
 
         rewriter.setInsertionPoint(outerLoop);
 
         auto c0 =
             arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-
         auto c1 =
             arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
-
         auto spillLoopBound = arith::ConstantIndexOp::create(
             rewriter, outerLoop.getLoc(), 16 * blockingFactor);
-        Value subBRLoop = arith::SubIOp::create(rewriter, outerLoop.getLoc(),
-                                                outerLoop.getUpperBound(), c1);
-        Value subKloop =
+
+        Value spillOuterLoop = arith::SubIOp::create(
+            rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
+        Value spillInnerLoop =
             arith::SubIOp::create(rewriter, innerLoop.getLoc(),
                                   innerLoop.getUpperBound(), spillLoopBound);
         auto bufferType =
             MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
-        auto bBuffer =
+        auto packedBuffer =
             memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
 
         // First Shuffling outside the reduction loops
@@ -994,28 +1065,29 @@ struct VectorContractToAMXDotProduct
         auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
 
         performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
-                       ipType, blockingFactor, bBuffer, c0);
+                       ipType, blockingFactor, packedBuffer, c0);
 
         // First Set of Loops
-        auto newLoop1 = scf::ForOp::create(
-            rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(), subBRLoop,
-            outerLoop.getStep(), loopItrArgs,
+        auto newLoopNonSpill = scf::ForOp::create(
+            rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+            spillOuterLoop, outerLoop.getStep(), loopItrArgs,
             [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
                 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
               auto newInnerLoop1 = createLoops(
                   rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
-                  subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
-                  opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
-                  contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
-                  true, spillLoopBound, nDimK, oddDimK);
+                  spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
+                  ipType, opType, blockingFactor, isVnni, vectorOpLhs,
+                  vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
+                  ivOuterLoop, packedBuffer, true, spillLoopBound,
+                  isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
 
-              auto newInnerLoop =
-                  createLoops(rewriter, innerLoop.getLoc(), subKloop,
-                              innerLoop.getUpperBound(), innerLoop.getStep(),
-                              newInnerLoop1.getResults(), ipType, opType,
-                              blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
-                              contractOp, outerLoop, innerLoop, ops,
-                              ivOuterLoop, bBuffer, true, c0, nDimK, oddDimK);
+              auto newInnerLoop = createLoops(
+                  rewriter, innerLoop.getLoc(), spillInnerLoop,
+                  innerLoop.getUpperBound(), innerLoop.getStep(),
+                  newInnerLoop1.getResults(), ipType, opType, blockingFactor,
+                  isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
+                  innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
+                  isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
 
               scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
                                    newInnerLoop.getResults());
@@ -1023,24 +1095,26 @@ struct VectorContractToAMXDotProduct
 
         // Last set of Loops
         newLoop = scf::ForOp::create(
-            rewriter, outerLoop.getLoc(), subBRLoop, outerLoop.getUpperBound(),
-            outerLoop.getStep(), newLoop1.getResults(),
+            rewriter, outerLoop.getLoc(), spillOuterLoop,
+            outerLoop.getUpperBound(), outerLoop.getStep(),
+            newLoopNonSpill.getResults(),
             [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
                 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
               auto newInnerLoop1 = createLoops(
                   rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
-                  subKloop, innerLoop.getStep(), iterArgsOuterLoop, ipType,
-                  opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
-                  contractOp, outerLoop, innerLoop, ops, ivOuterLoop, bBuffer,
-                  true, spillLoopBound, nDimK, oddDimK);
+                  spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
+                  ipType, opType, blockingFactor, isVnni, vectorOpLhs,
+                  vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
+                  ivOuterLoop, packedBuffer, true, spillLoopBound,
+                  isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
 
-              auto newInnerLoop =
-                  createLoops(rewriter, innerLoop.getLoc(), subKloop,
-                              innerLoop.getUpperBound(), innerLoop.getStep(),
-                              newInnerLoop1.getResults(), ipType, opType,
-                              blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
-                              contractOp, outerLoop, innerLoop, ops,
-                              ivOuterLoop, bBuffer, false, c0, nDimK, oddDimK);
+              auto newInnerLoop = createLoops(
+                  rewriter, innerLoop.getLoc(), spillInnerLoop,
+                  innerLoop.getUpperBound(), innerLoop.getStep(),
+                  newInnerLoop1.getResults(), ipType, opType, blockingFactor,
+                  isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
+                  innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
+                  isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
 
               scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
                                    newInnerLoop.getResults());
@@ -1065,8 +1139,8 @@ struct VectorContractToAMXDotProduct
             nullptr, false, false);
 
       } else {
-        bool nDimK = false;
-        bool oddDimK = false;
+        bool isInnerLoopUBLarger = false;
+        bool isInnerLoopUBHasOddQuot = false;
 
         int64_t ubVal = 16 * blockingFactor;
         mlir::Value ub = innerLoop.getUpperBound();
@@ -1077,21 +1151,23 @@ struct VectorContractToAMXDotProduct
           }
         }
 
-        nDimK = ubVal > 16 * blockingFactor;
-        oddDimK = (((ubVal / (16 * blockingFactor)) % 2) == 1) && nDimK;
+        isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
+        isInnerLoopUBHasOddQuot =
+            (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
+
         rewriter.setInsertionPoint(innerLoop);
         auto c0 =
             arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
         auto spillLoopBound = arith::ConstantIndexOp::create(
             rewriter, innerLoop.getLoc(), 16 * blockingFactor);
 
-        Value subKloop =
+        Value spillInnerLoop =
             arith::SubIOp::create(rewriter, innerLoop.getLoc(),
                                   innerLoop.getUpperBound(), spillLoopBound);
 
         auto bufferType =
             MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
-        auto bBuffer =
+        auto packedBuffer =
             memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
 
         // First Shuffling outside the reduction loops
@@ -1103,20 +1179,22 @@ struct VectorContractToAMXDotProduct
         auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
 
         performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
-                       ipType, blockingFactor, bBuffer, c0);
+                       ipType, blockingFactor, packedBuffer, c0);
 
-        auto newLoop1 = createLoops(
-            rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(), subKloop,
-            innerLoop.getStep(), loopItrArgs, ipType, opType, blockingFactor,
-            isVnni, vectorOpLhs, vectorOpRhs, contractOp, nullptr, innerLoop,
-            ops, nullptr, bBuffer, true, spillLoopBound, nDimK, oddDimK);
+        auto newLoopNonSpill = createLoops(
+            rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
+            spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
+            blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
+            nullptr, innerLoop, ops, nullptr, packedBuffer, true,
+            spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
 
-        newLoop = createLoops(rewriter, innerLoop.getLoc(), subKloop,
+        newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
                               innerLoop.getUpperBound(), innerLoop.getStep(),
-                              newLoop1.getResults(), ipType, opType,
+                              newLoopNonSpill.getResults(), ipType, opType,
                               blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
                               contractOp, nullptr, innerLoop, ops, nullptr,
-                              bBuffer, false, c0, nDimK, oddDimK);
+                              packedBuffer, false, c0, isInnerLoopUBLarger,
+                              isInnerLoopUBHasOddQuot);
       }
     }
 
@@ -1125,17 +1203,41 @@ struct VectorContractToAMXDotProduct
 
     if (!isVnni) {
       Location loc = outerLoop.getLoc();
+      Operation *accReadOp =
+          traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+      Value srcBuffAcc;
+      SmallVector<Value> indicesAcc;
+
+      llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
+          [&](auto readOp) {
+            srcBuffAcc = readOp.getOperand(0);
+
+            auto indices = readOp.getIndices();
+            indicesAcc.reserve(indices.size());
+
+            llvm::transform(indices, std::back_inserter(indicesAcc),
+                            [&](OpFoldResult ofr) {
+                              return mlir::getValueOrCreateConstantIndexOp(
+                                  rewriter, loc, ofr);
+                            });
+          });
+
+      auto outputShapes =
+          mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
+      unsigned int M = outputShapes[outputShapes.size() - 2];
+      unsigned int N = outputShapes[outputShapes.size() - 1];
+
       SmallVector<Value> dps = newLoop.getResults();
-      auto bufferType = MemRefType::get({32, 32}, opType);
-      auto bBuffer =
-          memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
-      for (int i = 0, k = 0; i < 32; i = i + 16) {
-        for (int j = 0; j < 32; j = j + 16) {
-          Value indexOp_i =
-              arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), i);
-          Value indexOp_j =
-              arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), j);
-          amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+      auto bufferType = MemRefType::get({M, N}, opType);
+      auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+      // Store the amx tiled-dot product output into an MxN memref.
+      for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
+        for (unsigned int j = 0; j < N; j = j + 16) {
+          Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+          Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+          amx::TileStoreOp::create(rewriter, loc, resultBuffer,
                                    ValueRange{indexOp_i, indexOp_j}, dps[k]);
           k++;
         }
@@ -1143,19 +1245,21 @@ struct VectorContractToAMXDotProduct
       auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
       auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
       auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
-      auto mBound = arith::ConstantIndexOp::create(rewriter, loc, 32);
+      auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
 
+      // Create a loop over the MxN memref that loads the two 16-wide halves of
+      // each row, shuffles them, adds the C element values, and stores back.
       scf::ForOp::create(
           rewriter, loc, c0, mBound, one, ValueRange{},
           [&](OpBuilder &nestedBuilder, Location loc, Value iv,
               ValueRange iterArgs) {
             auto row = vector::LoadOp::create(rewriter, loc,
                                               VectorType::get(16, opType),
-                                              bBuffer, ValueRange{iv, c0});
+                                              resultBuffer, ValueRange{iv, c0});
 
-            auto row2 = vector::LoadOp::create(rewriter, loc,
-                                               VectorType::get(16, opType),
-                                               bBuffer, ValueRange{iv, c16});
+            auto row2 = vector::LoadOp::create(
+                rewriter, loc, VectorType::get(16, opType), resultBuffer,
+                ValueRange{iv, c16});
 
             auto shuffle1 = vector::ShuffleOp::create(
                 rewriter, loc, VectorType::get(16, opType), row, row2,
@@ -1167,27 +1271,6 @@ struct VectorContractToAMXDotProduct
                 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
                                   28, 29, 30, 31});
 
-            Operation *accReadOp =
-                traceToVectorReadLikeParentOperation(contractOp.getAcc());
-
-            Value srcBuffAcc;
-            SmallVector<Value> indicesAcc;
-
-            llvm::TypeSwitch<Operation *>(accReadOp)
-                .Case<TransferReadOp, LoadOp>([&](auto readOp) {
-                  srcBuffAcc = readOp.getOperand(0);
-
-                  auto indices = readOp.getIndices();
-                  indicesAcc.reserve(indices.size());
-
-                  llvm::transform(
-                      indices, std::back_inserter(indicesAcc),
-                      [&](OpFoldResult ofr) {
-                        return mlir::getValueOrCreateConstantIndexOp(rewriter,
-                                                                     loc, ofr);
-                      });
-                });
-
             indicesAcc[indicesAcc.size() - 2] = iv;
             indicesAcc[indicesAcc.size() - 1] = c0;
 
@@ -1199,6 +1282,7 @@ struct VectorContractToAMXDotProduct
             Value valueCRow2 = vector::LoadOp::create(
                 rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
                 indicesAcc);
+
             Value addOp;
             Value addOp2;
 
@@ -1227,11 +1311,12 @@ struct VectorContractToAMXDotProduct
             scf::YieldOp::create(nestedBuilder, loc);
           });
     }
+
     auto bufferType = MemRefType::get({16, 16}, opType);
-    auto bBuffer =
+    auto resultBuffer =
         memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
-
     SmallVector<Value> dps = newLoop.getResults();
+
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
       Operation *resultWriteOp =
@@ -1242,7 +1327,7 @@ struct VectorContractToAMXDotProduct
         Value indexOp_0 =
             arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
 
-        amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
+        amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
                                  ValueRange{indexOp_0, indexOp_0}, dps[i]);
 
         auto c0 =
@@ -1257,7 +1342,7 @@ struct VectorContractToAMXDotProduct
             [&](OpBuilder &builder, Location loc, Value iv,
                 ValueRange iterArgs) {
               auto resultAcc = vector::LoadOp::create(
-                  rewriter, loc, VectorType::get(16, opType), bBuffer,
+                  rewriter, loc, VectorType::get(16, opType), resultBuffer,
                   ValueRange{iv, c0});
 
               Operation *accReadOp =
@@ -1283,7 +1368,7 @@ struct VectorContractToAMXDotProduct
 
               Value sum =
                   arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
-              indicesAcc[0] = sum;
+              indicesAcc[indicesAcc.size() - 2] = sum;
 
               auto acc = vector::LoadOp::create(rewriter, loc,
                                                 VectorType::get(16, opType),
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 151946453df81..6bb90a80da66e 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -1341,3 +1341,154 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<16x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<16x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @negative_vc_wrong_order_no_pair(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1]
+                : memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %4:2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg5] [16, 64] [1, 1]
+                : memref<64x256xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1]
+                : memref<256x128xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]}
+                : !memrefA, !vecA
+        %6 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]}
+                : !memrefB, !vecB
+        %7 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]}
+                : !memrefB, !vecB
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        scf.yield %9, %8 : !vecC, !vecC
+      }
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]}
+                : !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]}
+                : !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xi32>
+  memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+  return %alloc : memref<64x128xi32>
+}
+
+
+// CHECK-LABEL: @negative_vc_wrong_order_no_pair
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<16x64xi8>
+!vecB = vector<64x16xi8>
+!vecC = vector<16x16xi32>
+!memrefA = memref<32x64xi8, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x16xi8, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x16xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @negative_vc_no_pair(%arg0: memref<64x256xi8>, %arg1: memref<256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c16 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 16] [1, 1] : memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %4:2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1]
+                : memref<64x256xi8> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 16] [1, 1]
+                : memref<256x128xi8> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]}
+                : !memrefA, !vecA
+        %6 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]}
+                : !memrefA, !vecA
+        %7 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]}
+                : !memrefB, !vecB
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %6, %7, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        scf.yield %8, %9 : !vecC, !vecC
+      }
+      vector.transfer_write %4#1, %subview[%c16, %c0] {in_bounds = [true, true]}
+                : !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]}
+                : !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xi32>
+  memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+  return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @negative_vc_no_pair
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_muli
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From ed5e583ca32dc5c610caa6b09f7d2078051b7bb5 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 27 Mar 2026 08:50:28 -0700
Subject: [PATCH 3/5] code refactor.

---
 .../VectorContractToAMXDotProduct.cpp         | 53 ++++++++-----------
 1 file changed, 23 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 744c065b4e05e..63e187613bfc4 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -70,9 +70,8 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
   if (!srcBuff)
     return failure();
 
-  if (isNotAcc) {
+  if (isNotAcc)
     indexVals.pop_back();
-  }
 
   SmallVector<Value> indices;
   indices.reserve(indexVals.size());
@@ -206,8 +205,7 @@ static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
   }
 
   amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
-  auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
-  return load;
+  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
 }
 
 static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
@@ -424,11 +422,10 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
-static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
-                               Value ivInnerLoop, Value ivOuterLoop,
-                               bool isInnerLoopUBHasOddQuot,
-                               bool isInnerLoopUBLarger, bool pack,
-                               unsigned int blockingFactor) {
+static Value getIndxToLoadStoreFromPckBuffer(
+    OpBuilder &rewriter, Location loc, Value ivInnerLoop, Value ivOuterLoop,
+    bool isInnerLoopUBHasOddQuot, bool isInnerLoopUBLarger, bool pack,
+    unsigned int blockingFactor) {
 
   Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
   Value packOffset =
@@ -444,8 +441,6 @@ static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
         rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
   }
 
-  // if K quotient is odd. Then, BR loop iv is taken
-  // into consideration
   if (isInnerLoopUBHasOddQuot) {
     auto remOuterLoop = arith::RemUIOp::create(
         rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
@@ -454,6 +449,7 @@ static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
     remInnerLoop = arith::RemUIOp::create(rewriter, loc,
                                           rewriter.getIndexType(), remAdd, c2);
   }
+
   return remInnerLoop;
 }
 
@@ -477,11 +473,11 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
       [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
           Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
         IRMapping mapping;
-        if (outerLoop) {
+        if (outerLoop)
           mapping.map(vectorOpLhs->getOperand(
                           getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
                       ivOuterLoop);
-        }
+
         mapping.map(vectorOpLhs->getOperand(
                         getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
                     ivNewInnerLoop);
@@ -516,10 +512,10 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                   rewriter, locNewInnerLoop, (16 * blockingFactor));
               ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                      nLoadIndx, ivNewInnerLoop);
-              indxToStoreInBuffer =
-                  bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
-                                    isInnerLoopUBHasOddQuot,
-                                    isInnerLoopUBLarger, pack, blockingFactor);
+              indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
+                  rewriter, loc, ivNewInnerLoop, ivOuterLoop,
+                  isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
+                  blockingFactor);
               Value indxToLoadFromMatB =
                   arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
               indxToLoadFromBuffer =
@@ -547,12 +543,12 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
         }
 
         IRMapping rhsMapping;
-        if (outerLoop) {
+        if (outerLoop)
           rhsMapping.map(
               vectorOpRhs->getOperand(
                   getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
               ivOuterLoop);
-        }
+
         rhsMapping.map(
             vectorOpRhs->getOperand(
                 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
@@ -568,10 +564,10 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
                   rewriter, locNewInnerLoop, (16 * blockingFactor));
               matB = Value();
               indxToLoadFromBuffer = c0;
-              indxToLoadFromBuffer =
-                  bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
-                                    isInnerLoopUBHasOddQuot,
-                                    isInnerLoopUBLarger, pack, blockingFactor);
+              indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
+                  rewriter, loc, nLoadIndx, ivOuterLoop,
+                  isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
+                  blockingFactor);
             }
           } else {
             if (!pack) {
@@ -709,13 +705,8 @@ struct VectorContractToAMXDotProduct
                       "transfer_read or a load. And, the result should be "
                       "stored using transfer_write or store.");
 
-    Type ipType;
-    Type opType;
-
-    if (lhsTy.getElementType().isBF16()) {
-      ipType = rewriter.getBF16Type();
-      opType = rewriter.getF32Type();
-    }
+    Type ipType = rewriter.getBF16Type();
+    Type opType = rewriter.getF32Type();
 
     if (lhsTy.getElementType().isSignlessInteger(8)) {
       ipType = rewriter.getIntegerType(8);
@@ -910,6 +901,7 @@ struct VectorContractToAMXDotProduct
     // Case 2: The acc are passed as iter args through the reduction loop.
     // We support, reduction loop depth until 2. TODO: Support for n-depth
     // reduction loop.
+    // TODO: Refactor cases 2a and 2b.
     SmallVector<scf::ForOp> loopLists;
     Operation *current = contractOp;
     while (true) {
@@ -1122,6 +1114,7 @@ struct VectorContractToAMXDotProduct
       }
     }
 
+    // Case 2b: Reduction loop depth is 1.
     if (loopLists.size() == 1) {
       outerLoop = loopLists[0];
       innerLoop = loopLists[0];

>From 35ecac82ba288eb9c49124e0dcb2161801e8895f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 30 Mar 2026 08:46:25 -0700
Subject: [PATCH 4/5] minor change to shuffling order for int8

---
 .../Transforms/VectorContractToAMXDotProduct.cpp | 16 ++++++++--------
 .../X86/AMX/vector-contract-to-tiled-dp.mlir     |  4 ++--
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 63e187613bfc4..eb35720b657ea 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -267,19 +267,19 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
               vec2,
               ArrayRef<int64_t>{
                   0,   32,  64, 96,  1,   33,  65,  97,  2,   34,  66,  98, 3,
-                  35,  67,  99, 4,   36,  68,  100, 5,   37,  69,  101, 6,  38,
-                  70,  102, 7,  39,  71,  103, 8,   40,  72,  104, 9,   41, 73,
-                  105, 10,  42, 74,  106, 11,  43,  75,  107, 12,  44,  76, 108,
-                  13,  45,  77, 109, 14,  46,  78,  110, 15,  47,  79,  111});
+                  35,  67,  99, 8,   40,  72,  104, 9,   41,  73,  105, 10, 42,
+                  74,  106, 11, 43,  75,  107, 16,  48,  80,  112, 17,  49, 81,
+                  113, 18,  50, 82,  114, 19,  51,  83,  115, 24,  56,  88, 120,
+                  25,  57,  89, 121, 26,  58,  90,  122, 27,  59,  91,  123});
 
           shuffle2 = vector::ShuffleOp::create(
               rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
               vec2,
               ArrayRef<int64_t>{
-                  16, 48,  80, 112, 17, 49,  81, 113, 18, 50,  82, 114, 19, 51,
-                  83, 115, 20, 52,  84, 116, 21, 53,  85, 117, 22, 54,  86, 118,
-                  23, 55,  87, 119, 24, 56,  88, 120, 25, 57,  89, 121, 26, 58,
-                  90, 122, 27, 59,  91, 123, 28, 60,  92, 124, 29, 61,  93, 125,
+                  4,  36,  68, 100, 5,  37,  69, 101, 6,  38,  70, 102, 7,  39,
+                  71, 103, 12, 44,  76, 108, 13, 45,  77, 109, 14, 46,  78, 110,
+                  15, 47,  79, 111, 20, 52,  84, 116, 21, 53,  85, 117, 22, 54,
+                  86, 118, 23, 55,  87, 119, 28, 60,  92, 124, 29, 61,  93, 125,
                   30, 62,  94, 126, 31, 63,  95, 127});
         }
 
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 6bb90a80da66e..1a6deed31eceb 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -770,8 +770,8 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
 // CHECK-LABEL: @online_packing_int8_matmul_loop
 // CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xi32>
 // CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>) {
-// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111] : vector<64xi8>, vector<64xi8>
-// CHECK-NEXT: vector.shuffle{{.*}}[16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
 // CHECK: x86.amx.tile_load
 // CHECK: x86.amx.tile_muli
 // CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>

>From b1b4c569ab021bda2906ee8601687c1721d6d289 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 8 Apr 2026 11:08:57 -0700
Subject: [PATCH 5/5] remove unnecessary braces

---
 .../X86/Transforms/VectorContractToAMXDotProduct.cpp     | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index eb35720b657ea..cc66308e98260 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -966,10 +966,8 @@ struct VectorContractToAMXDotProduct
       unsigned int pairCount = 0;
       for (size_t j = 0; j < ops.size(); j++) {
         for (size_t i = j; i < ops.size(); i++) {
-
-          if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+          if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16))
             pairCount = pairCount + 2;
-          }
         }
       }
 
@@ -1116,7 +1114,6 @@ struct VectorContractToAMXDotProduct
 
     // Case 2b: Reduction loop depth is 1.
     if (loopLists.size() == 1) {
-      outerLoop = loopLists[0];
       innerLoop = loopLists[0];
 
       SmallVector<Value> loopItrArgs = createTileZeros(
@@ -1189,6 +1186,10 @@ struct VectorContractToAMXDotProduct
                               packedBuffer, false, c0, isInnerLoopUBLarger,
                               isInnerLoopUBHasOddQuot);
       }
+
+      // This lets the final store back to the acc use the same code for
+      // both reduction loop depths (1 and 2).
+      outerLoop = innerLoop;
     }
 
     // Copy the amx tile accumulation results to a MemRef buffer, add the



More information about the Mlir-commits mailing list