[Mlir-commits] [mlir] [mlir][x86] Fix - multiple issues / F8 support for AMX dot-product lowering (PR #196984)

Arun Thangamani llvmlistbot at llvm.org
Mon May 25 04:51:51 PDT 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/196984

>From 91394a7e8e01385a767407aa2340cf9e2486bc4c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 11 May 2026 09:02:23 -0700
Subject: [PATCH 01/10] fixes issues with AMX dot-product lowering

---
 .../VectorContractToAMXDotProduct.cpp         | 391 +++++++++---------
 1 file changed, 202 insertions(+), 189 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 94b94292e675f..551fccb47e114 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,6 +27,31 @@ using namespace mlir::x86;
 
 namespace {
 
+/// Follows `v` through any chain of `scf.yield` ops to the matching result of
+/// the enclosing op and returns the first value whose sole user is not a
+/// yield. Returns a null Value when a value in the chain has no use or more
+/// than one use, or when the enclosing op has no result at the yield index.
+static Value traceToVectorWriteLikeUserOperationForAMX(Value v) {
+  while (v) {
+    // Only a single-use chain can be rewired unambiguously.
+    if (!v.hasOneUse())
+      return nullptr;
+    OpOperand &use = *v.use_begin();
+    auto yield = dyn_cast<scf::YieldOp>(use.getOwner());
+    if (!yield)
+      return v;
+
+    // The yielded operand corresponds to the parent's result at the same
+    // position; bail out instead of asserting if no such result exists.
+    Operation *parent = yield->getParentOp();
+    unsigned idx = use.getOperandNumber();
+    if (idx >= parent->getNumResults())
+      return nullptr;
+    v = parent->getResult(idx);
+  }
+  return nullptr;
+}
+
 // Function to collapse the last two dimension (vnni and k) to help the
 // amx.tile_load to correctly load the packed element type.
 static Value collapseInnerDims(OpBuilder &builder, mlir::Location loc,
@@ -216,22 +241,30 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
   Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
 
   auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
-  SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+  SmallVector<Value> subviewOffset(subview.getMixedOffsets().size(), c0);
 
   Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
   Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
   Value offsetIndx =
       arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
 
+  // llvm::outs() << "check-a:" << matB << " subview:" << subview << "\n";
+  // llvm::outs() << "The size:" << subview.getMixedOffsets().size() << "\n";
+
   scf::ForOp::create(
       rewriter, loc, c0, cBound, cStep, ValueRange{},
       [&](OpBuilder &nestedBuilder, Location loc, Value iv,
           ValueRange iterArgs) {
+        // llvm::outs() << "check-a0" << subviewOffset.size() <<  "\n";
         subviewOffset[subviewOffset.size() - 2] = iv;
+
+        // llvm::outs() << "check-a1" << "\n";
         auto vec1 = vector::LoadOp::create(
             rewriter, loc, VectorType::get((16 * offset), ipType), matB,
             ValueRange(subviewOffset));
 
+        // llvm::outs() << "check-b" << "\n";
+
         // Increment the iv by 1 or 2 based on the type to load the next 32/64
         // elements
         Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
@@ -243,6 +276,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
         vector::ShuffleOp shuffle1;
         vector::ShuffleOp shuffle2;
 
+        // llvm::outs() << "check-c" << "\n";
+
         if (ipType.isBF16()) {
 
           shuffle1 = vector::ShuffleOp::create(
@@ -283,6 +318,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
                   30, 62,  94, 126, 31, 63,  95, 127});
         }
 
+        // llvm::outs() << "check-d" << "\n";
+
         // iv to store the shuffled elements
         Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
         Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
@@ -468,6 +505,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
   Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
   Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
 
+  int64_t offset = step.getDefiningOp<arith::ConstantIndexOp>().value();
+
   auto newLoop = scf::ForOp::create(
       rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
       [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
@@ -485,7 +524,6 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
 
         Value indxToStoreInBuffer = c0;
         Value indxToLoadFromBuffer = c0;
-
         if (!isVnni) {
           if (outerLoop) {
             if (innerLoopIndex.value() == 0) {
@@ -509,7 +547,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
 
             } else {
               Value nLoadIndx = arith::ConstantIndexOp::create(
-                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+                  rewriter, locNewInnerLoop, offset);
               ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                      nLoadIndx, ivNewInnerLoop);
               indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
@@ -525,7 +563,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
           } else {
             if (pack) {
               Value nLoadIndx = arith::ConstantIndexOp::create(
-                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+                  rewriter, locNewInnerLoop, offset);
               ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
                                                      nLoadIndx, ivNewInnerLoop);
               Value quotient_K = arith::DivUIOp::create(
@@ -541,27 +579,49 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
             }
           }
         }
-
         IRMapping rhsMapping;
-        if (outerLoop)
-          rhsMapping.map(
-              vectorOpRhs->getOperand(
-                  getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
-              ivOuterLoop);
 
-        rhsMapping.map(
-            vectorOpRhs->getOperand(
-                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
-            ivNewInnerLoop);
-        auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+        Value matB;
+        Operation *rhsOp = vectorOpRhs;
 
-        Value matB = rhsClone->getResult(0);
+        // Clone only if the op has operands.
+        if (rhsOp->getNumOperands() > 0) {
 
+          if (outerLoop) {
+            int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
+
+            if (outerPos >= 0) {
+              unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
+
+              if (operandIdx < rhsOp->getNumOperands()) {
+                rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
+              }
+            }
+          }
+
+          int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
+
+          if (innerPos >= 0) {
+            unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
+
+            if (operandIdx < rhsOp->getNumOperands()) {
+              rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
+            }
+          }
+
+          auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
+
+          matB = rhsClone->getResult(0);
+
+        } else {
+          // memref.get_global / constants
+          matB = rhsOp->getResult(0);
+        }
         if (!isVnni) {
           if (outerLoop) {
             if (!pack) {
               Value nLoadIndx = arith::ConstantIndexOp::create(
-                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+                  rewriter, locNewInnerLoop, offset);
               matB = Value();
               indxToLoadFromBuffer = c0;
               indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
@@ -572,7 +632,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
           } else {
             if (!pack) {
               Value nLoadIndx = arith::ConstantIndexOp::create(
-                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+                  rewriter, locNewInnerLoop, offset);
               matB = Value();
               Value quotient_K = arith::DivUIOp::create(
                   rewriter, loc, ivNewInnerLoop, nLoadIndx);
@@ -581,7 +641,6 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
             }
           }
         }
-
         // compute tiled dot-product
         SmallVector<Value> accumulators = createTiledDp(
             rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
@@ -860,7 +919,7 @@ struct VectorContractToAMXDotProduct
                                       14, 30, 46, 62, 15, 31, 47, 63});
               }
 
-              auto rem = arith::RemUIOp::create(
+              auto rem = arith::DivUIOp::create(
                   rewriter, loc, rewriter.getIndexType(), iv, step);
 
               vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
@@ -988,6 +1047,7 @@ struct VectorContractToAMXDotProduct
     scf::ForOp newLoop;
     // Case 2a: Reduction loop depth is 2.
     if (loopLists.size() == 2) {
+
       outerLoop = loopLists[1];
       innerLoop = loopLists[0];
 
@@ -1120,8 +1180,8 @@ struct VectorContractToAMXDotProduct
 
     // Case 2b: Reduction loop depth is 1.
     if (loopLists.size() == 1) {
-      innerLoop = loopLists[0];
 
+      innerLoop = loopLists[0];
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
 
@@ -1135,6 +1195,7 @@ struct VectorContractToAMXDotProduct
             nullptr, false, false);
 
       } else {
+
         bool isInnerLoopUBLarger = false;
         bool isInnerLoopUBHasOddQuot = false;
 
@@ -1154,8 +1215,12 @@ struct VectorContractToAMXDotProduct
         rewriter.setInsertionPoint(innerLoop);
         auto c0 =
             arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
+
+        int64_t stepVal =
+            innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>().value();
+
         auto spillLoopBound = arith::ConstantIndexOp::create(
-            rewriter, innerLoop.getLoc(), 16 * blockingFactor);
+            rewriter, innerLoop.getLoc(), stepVal);
 
         Value spillInnerLoop =
             arith::SubIOp::create(rewriter, innerLoop.getLoc(),
@@ -1173,10 +1238,8 @@ struct VectorContractToAMXDotProduct
                 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
             c0);
         auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
-
         performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
                        ipType, blockingFactor, packedBuffer, c0);
-
         auto newLoopNonSpill = createLoops(
             rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
             spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
@@ -1200,194 +1263,144 @@ struct VectorContractToAMXDotProduct
 
     // Copy the amx tile accumulation results to a MemRef buffer, add the
     // initial accumulation value, and store back to the C-Matrix
+    Location loc = outerLoop.getLoc();
+    Value srcBuffAcc;
+    SmallVector<Value> indicesAcc;
+
+    llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
+        [&](auto readOp) {
+          srcBuffAcc = readOp.getOperand(0);
+
+          auto indices = readOp.getIndices();
+          indicesAcc.reserve(indices.size());
+
+          llvm::transform(indices, std::back_inserter(indicesAcc),
+                          [&](OpFoldResult ofr) {
+                            return mlir::getValueOrCreateConstantIndexOp(
+                                rewriter, loc, ofr);
+                          });
+        });
 
-    if (!isVnni) {
-      Location loc = outerLoop.getLoc();
-      Operation *accReadOp =
-          traceToVectorReadLikeParentOperation(contractOp.getAcc());
-
-      Value srcBuffAcc;
-      SmallVector<Value> indicesAcc;
-
-      llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
-          [&](auto readOp) {
-            srcBuffAcc = readOp.getOperand(0);
-
-            auto indices = readOp.getIndices();
-            indicesAcc.reserve(indices.size());
-
-            llvm::transform(indices, std::back_inserter(indicesAcc),
-                            [&](OpFoldResult ofr) {
-                              return mlir::getValueOrCreateConstantIndexOp(
-                                  rewriter, loc, ofr);
-                            });
-          });
-
-      auto outputShapes =
-          mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
-      unsigned int M = outputShapes[outputShapes.size() - 2];
-      unsigned int N = outputShapes[outputShapes.size() - 1];
-
-      SmallVector<Value> dps = newLoop.getResults();
-      auto bufferType = MemRefType::get({M, N}, opType);
-      auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
-
-      // Store the amx tiled-dot product output into an MxN memref.
-      for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
-        for (unsigned int j = 0; j < N; j = j + 16) {
-          Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
-          Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
-          amx::TileStoreOp::create(rewriter, loc, resultBuffer,
-                                   ValueRange{indexOp_i, indexOp_j}, dps[k]);
-          k++;
-        }
+    auto outputShapes =
+        mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
+    unsigned int M = outputShapes[outputShapes.size() - 2];
+    unsigned int N = outputShapes[outputShapes.size() - 1];
+
+    SmallVector<Value> dps = newLoop.getResults();
+    auto bufferType = MemRefType::get({M, N}, opType);
+    auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+    // Store the amx tiled-dot product output into an MxN memref.
+    for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
+      for (unsigned int j = 0; j < N; j = j + 16) {
+        Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+        Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+        amx::TileStoreOp::create(rewriter, loc, resultBuffer,
+                                 ValueRange{indexOp_i, indexOp_j}, dps[k]);
+        k++;
       }
-      auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-      auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
-      auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
-      auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
-
-      // Create a loop that iterates over the MxN memerf, retrives two rows +
-      // shuffle them, add up the C element values and stores them back.
-      scf::ForOp::create(
-          rewriter, loc, c0, mBound, one, ValueRange{},
-          [&](OpBuilder &nestedBuilder, Location loc, Value iv,
-              ValueRange iterArgs) {
-            auto row = vector::LoadOp::create(rewriter, loc,
-                                              VectorType::get(16, opType),
-                                              resultBuffer, ValueRange{iv, c0});
-
-            auto row2 = vector::LoadOp::create(
-                rewriter, loc, VectorType::get(16, opType), resultBuffer,
-                ValueRange{iv, c16});
-
-            auto shuffle1 = vector::ShuffleOp::create(
+    }
+    auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
+
+    // Create a loop that iterates over the MxN memref, retrieves two rows,
+    // shuffles them, adds the C element values, and stores them back.
+    scf::ForOp::create(
+        rewriter, loc, c0, mBound, one, ValueRange{},
+        [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+            ValueRange iterArgs) {
+          auto row =
+              vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+                                     resultBuffer, ValueRange{iv, c0});
+
+          auto row2 =
+              vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+                                     resultBuffer, ValueRange{iv, c16});
+
+          Value shuffle1 = row;
+          Value shuffle2 = row2;
+
+          if (!isVnni) {
+            shuffle1 = vector::ShuffleOp::create(
                 rewriter, loc, VectorType::get(16, opType), row, row2,
                 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
                                   21, 22, 23});
 
-            auto shuffle2 = vector::ShuffleOp::create(
+            shuffle2 = vector::ShuffleOp::create(
                 rewriter, loc, VectorType::get(16, opType), row, row2,
                 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
                                   28, 29, 30, 31});
+          }
+          indicesAcc[indicesAcc.size() - 2] = iv;
+          indicesAcc[indicesAcc.size() - 1] = c0;
 
-            indicesAcc[indicesAcc.size() - 2] = iv;
-            indicesAcc[indicesAcc.size() - 1] = c0;
+          Value valueCRow1 =
+              vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+                                     srcBuffAcc, indicesAcc);
+          indicesAcc[indicesAcc.size() - 1] = c16;
 
-            Value valueCRow1 = vector::LoadOp::create(
-                rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
-                indicesAcc);
-            indicesAcc[indicesAcc.size() - 1] = c16;
+          Value valueCRow2 =
+              vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
+                                     srcBuffAcc, indicesAcc);
 
-            Value valueCRow2 = vector::LoadOp::create(
-                rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
-                indicesAcc);
+          Value addOp;
+          Value addOp2;
 
-            Value addOp;
-            Value addOp2;
+          if (ipType.isBF16()) {
+            addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
 
-            if (ipType.isBF16()) {
-              addOp =
-                  arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
+            addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
+          }
 
-              addOp2 =
-                  arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
-            }
+          if (ipType.isSignlessInteger(8)) {
+            addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
 
-            if (ipType.isSignlessInteger(8)) {
-              addOp =
-                  arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
+            addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
+          }
 
-              addOp2 =
-                  arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
-            }
-            indicesAcc[indicesAcc.size() - 1] = c0;
-            vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
-                                    indicesAcc);
-            indicesAcc[indicesAcc.size() - 1] = c16;
-            vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
-                                    indicesAcc);
-
-            scf::YieldOp::create(nestedBuilder, loc);
-          });
-    }
+          vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
+                                  ValueRange{iv, c0});
+          vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
+                                  ValueRange{iv, c16});
 
-    auto bufferType = MemRefType::get({16, 16}, opType);
-    auto resultBuffer =
-        memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
-    SmallVector<Value> dps = newLoop.getResults();
+          scf::YieldOp::create(nestedBuilder, loc);
+        });
+
+    SmallVector<Value> writeResults;
+    for (unsigned int i = 0; i < M; i = i + 16) {
+      for (unsigned int j = 0; j < N; j = j + 16) {
+        Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+        Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+        auto flatTy = mlir::VectorType::get({16, 16}, opType);
+
+        int64_t srcRank =
+            (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
+        Value padding = ub::PoisonOp::create(rewriter, loc, opType);
+        auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
+                                                  rewriter.getContext());
+        SmallVector<bool> inBounds(flatTy.getRank(), true);
+
+        auto vec1 = vector::TransferReadOp::create(
+            rewriter, loc, flatTy, resultBuffer,
+            ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
+        writeResults.push_back(vec1);
+      }
+    }
 
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
-      Operation *resultWriteOp =
-          traceToVectorWriteLikeUserOperation(contOp.getResult());
-      if (isVnni) {
-        rewriter.setInsertionPoint(resultWriteOp);
-
-        Value indexOp_0 =
-            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-
-        amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
-                                 ValueRange{indexOp_0, indexOp_0}, dps[i]);
-
-        auto c0 =
-            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
-        auto one =
-            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
-        auto mBound =
-            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 16);
-
-        scf::ForOp::create(
-            rewriter, outerLoop.getLoc(), c0, mBound, one, ValueRange{},
-            [&](OpBuilder &builder, Location loc, Value iv,
-                ValueRange iterArgs) {
-              auto resultAcc = vector::LoadOp::create(
-                  rewriter, loc, VectorType::get(16, opType), resultBuffer,
-                  ValueRange{iv, c0});
-
-              Operation *accReadOp =
-                  traceToVectorReadLikeParentOperation(ops[i].getAcc());
-
-              Value srcBuffAcc;
-              SmallVector<Value> indicesAcc;
+      Value vecRoc = writeResults[i];
 
-              llvm::TypeSwitch<Operation *>(accReadOp)
-                  .Case<TransferReadOp, LoadOp>([&](auto readOp) {
-                    srcBuffAcc = readOp.getOperand(0);
-
-                    auto indices = readOp.getIndices();
-                    indicesAcc.reserve(indices.size());
-
-                    llvm::transform(
-                        indices, std::back_inserter(indicesAcc),
-                        [&](OpFoldResult ofr) {
-                          return mlir::getValueOrCreateConstantIndexOp(
-                              rewriter, loc, ofr);
-                        });
-                  });
-
-              Value sum =
-                  arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
-              indicesAcc[indicesAcc.size() - 2] = sum;
-
-              auto acc = vector::LoadOp::create(rewriter, loc,
-                                                VectorType::get(16, opType),
-                                                srcBuffAcc, indicesAcc);
-              Value addition;
-              if (ipType.isBF16())
-                addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
-
-              if (ipType.isSignlessInteger(8))
-                addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
-
-              vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
-                                      indicesAcc);
-
-              scf::YieldOp::create(builder, outerLoop.getLoc());
-            });
+      Value resultWriteOp =
+          traceToVectorWriteLikeUserOperationForAMX(contOp.getResult());
+      if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType())) {
+        vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
+                                                   writeResults[i]);
       }
-
-      rewriter.eraseOp(resultWriteOp);
+      resultWriteOp.replaceAllUsesWith(vecRoc);
     }
 
     return success();

>From 8e448d5f2a3e369dcc9514587e242e913d3e7f67 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 05:55:29 -0700
Subject: [PATCH 02/10] counting offset on the subview result

---
 .../VectorContractToAMXDotProduct.cpp         |  47 ++++----
 .../X86/AMX/vector-contract-to-tiled-dp.mlir  | 100 ++++++++++++++++--
 2 files changed, 122 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 551fccb47e114..64e5a6b56504b 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -239,32 +239,24 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
 
   Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
   Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
-
-  auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
-  SmallVector<Value> subviewOffset(subview.getMixedOffsets().size(), c0);
+  SmallVector<Value> subviewOffset(
+      llvm::cast<MemRefType>(matB.getType()).getRank(), c0);
 
   Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
   Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
   Value offsetIndx =
       arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
 
-  // llvm::outs() << "check-a:" << matB << " subview:" << subview << "\n";
-  // llvm::outs() << "The size:" << subview.getMixedOffsets().size() << "\n";
-
   scf::ForOp::create(
       rewriter, loc, c0, cBound, cStep, ValueRange{},
       [&](OpBuilder &nestedBuilder, Location loc, Value iv,
           ValueRange iterArgs) {
-        // llvm::outs() << "check-a0" << subviewOffset.size() <<  "\n";
         subviewOffset[subviewOffset.size() - 2] = iv;
 
-        // llvm::outs() << "check-a1" << "\n";
         auto vec1 = vector::LoadOp::create(
             rewriter, loc, VectorType::get((16 * offset), ipType), matB,
             ValueRange(subviewOffset));
 
-        // llvm::outs() << "check-b" << "\n";
-
         // Increment the iv by 1 or 2 based on the type to load the next 32/64
         // elements
         Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
@@ -276,8 +268,6 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
         vector::ShuffleOp shuffle1;
         vector::ShuffleOp shuffle2;
 
-        // llvm::outs() << "check-c" << "\n";
-
         if (ipType.isBF16()) {
 
           shuffle1 = vector::ShuffleOp::create(
@@ -318,8 +308,6 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
                   30, 62,  94, 126, 31, 63,  95, 127});
         }
 
-        // llvm::outs() << "check-d" << "\n";
-
         // iv to store the shuffled elements
         Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
         Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
@@ -829,6 +817,8 @@ struct VectorContractToAMXDotProduct
                                            "The ACC src is not a MemRef type.");
       auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
 
+      Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
       // amx.tile_loads
       auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
       auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
@@ -856,7 +846,6 @@ struct VectorContractToAMXDotProduct
         auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
 
         // create a loop that does online packing.
-        Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
         Value step =
             arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
         Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
@@ -951,9 +940,32 @@ struct VectorContractToAMXDotProduct
         dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
                                      loadRhs, loadAcc);
 
-      amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
+      auto bufferType = MemRefType::get({16, 16}, opType);
+      auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
+
+      amx::TileStoreOp::create(rewriter, loc, resultBuffer, ValueRange{c0, c0},
+                               dp);
+
+      auto flatTy = mlir::VectorType::get({16, 16}, opType);
+      int64_t srcRank =
+          (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
+      Value padding = ub::PoisonOp::create(rewriter, loc, opType);
+      auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
+                                                rewriter.getContext());
+      SmallVector<bool> inBounds(flatTy.getRank(), true);
+
+      Value vecRow = vector::TransferReadOp::create(
+          rewriter, loc, flatTy, resultBuffer, ValueRange{c0, c0}, padding, map,
+          inBounds);
+
+      Value resultOp =
+          traceToVectorWriteLikeUserOperationForAMX(contractOp.getResult());
+      if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
+        vecRow =
+            mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
+      }
 
-      rewriter.eraseOp(resultWriteOp);
+      resultOp.replaceAllUsesWith(vecRow);
       return success();
     }
 
@@ -1186,7 +1198,6 @@ struct VectorContractToAMXDotProduct
           rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
 
       if (isVnni) {
-
         newLoop = createLoops(
             rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
             innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 1a6deed31eceb..20d269fd6ff88 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -239,13 +239,17 @@ func.func @online_packing_int8(
 
   %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
 
+  %bias = arith.constant dense<13> : !vecC
+
   %4 = vector.contract {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>}
     %1, %2, %3 : !vecA, !vecB into !vecC
 
-  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  %5 = arith.addi %4, %bias : !vecC
+
+  vector.transfer_write %5, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
 
   return %arg2 : !memrefC
 }
@@ -259,10 +263,11 @@ func.func @online_packing_int8(
 // CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xi32>
 // CHECK: x86.amx.tile_muli
 // CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xi32>
+// CHECK: vector.transfer_read
+// CHECK: arith.addi
+// CHECK: vector.transfer_write
 // CHECK-NOT: vector.contract
 
-
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -695,7 +700,80 @@ func.func @online_packing_bf16_loop(%arg0: memref<16x64x96xbf16>, %arg1: memref<
 // CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
 // CHECK-NOT: vector.contract
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecAB = vector<1x16x16x2xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
+!memrefB = memref<1x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @brgemm_bf16_with_cano(%arg0: memref<16x32x16x2xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xf32>) -> memref<32x32xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c1 = arith.constant 1 : index
+  %2 = vector.transfer_read %arg2[%c0, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+  %4 = vector.transfer_read %arg2[%c16, %c0], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+  %5 = vector.transfer_read %arg2[%c16, %c16], %0 {in_bounds = [true, true]} : memref<32x32xf32>, !vecC
+
+  %6:4 = scf.for %arg3 = %c0 to %c16 step %c1 iter_args(%arg4 = %2, %arg5 = %3, %arg6 = %4, %arg7 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+
+    %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x32x16x2xbf16> to !memrefA
+    %subview_0 = memref.subview %arg1[%arg3, 0, 0, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x16x32x2xbf16> to !memrefB
+
+    %7 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefA, !vecAB
+    %8 = vector.transfer_read %subview[%c0, %c16, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefA, !vecAB
+    %9 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefB, !vecAB
+    %10 = vector.transfer_read %subview_0[%c0, %c0, %c16, %c0], %1 {in_bounds = [true, true, true, true]} : !memrefB, !vecAB
+
+    %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+        ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+        %7, %9, %arg4 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+    %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+        ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+        %7, %10, %arg5 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+    %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+        ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+        %8, %9, %arg6 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+    %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+        ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+        %8, %10, %arg7 {unroll_shape = array<i64: 1, 2, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+    scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+  }
 
+  vector.transfer_write %6#3, %arg2[%c16, %c16] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+  vector.transfer_write %6#2, %arg2[%c16, %c0] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+  vector.transfer_write %6#1, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+  vector.transfer_write %6#0, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, memref<32x32xf32>
+  %alloc = memref.alloc() : memref<32x32xf32>
+  memref.copy %arg2, %alloc : memref<32x32xf32> to memref<32x32xf32>
+  return %alloc : memref<32x32xf32>
+}
+
+// CHECK-LABEL: @brgemm_bf16_with_cano
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK-COUNT-4: x86.amx.tile_load
+// CHECK-COUNT-4: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -728,6 +806,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
   %c128 = arith.constant 128 : index
   %c256 = arith.constant 256 : index
   %c32 = arith.constant 32 : index
+  %bias = arith.constant dense<13> : !vecC
   scf.for %arg3 = %c0 to %c64 step %c32 {
     scf.for %arg4 = %c0 to %c128 step %c32 {
       %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xi32> to !memrefC
@@ -756,10 +835,16 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
                 %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
         scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
       }
-      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
-      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
-      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
-      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+      %7 = arith.addi %6#3, %bias : !vecC
+      %8 = arith.addi %6#2, %bias : !vecC
+      %9 = arith.addi %6#1, %bias : !vecC
+      %10 = arith.addi %6#0, %bias : !vecC
+
+      vector.transfer_write %7, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %8, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %9, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %10, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
     }
   }
   %alloc = memref.alloc() : memref<64x128xi32>
@@ -777,6 +862,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xi8>, %arg1: memr
 // CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>, !x86.amx.tile<16x16xi32>
 // CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
 // CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK-COUNT-4: arith.addi
 // CHECK-NOT: scf.for {{.*}} vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>, vector<16x16xi32>
 // CHECK-NOT: vector.contract
 

>From 4dbcaf4def5e6039ce05b972de8691379535c171 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 07:52:25 -0700
Subject: [PATCH 03/10] enable support for f8 type

---
 .../VectorContractToAMXDotProduct.cpp         |  26 ++-
 .../X86/AMX/vector-contract-to-tiled-dp.mlir  | 219 ++++++++++++++++++
 2 files changed, 238 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 64e5a6b56504b..595f585ad2d61 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -285,7 +285,8 @@ static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
                                 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
         }
 
-        if (ipType.isSignlessInteger(8)) {
+        if (ipType.isSignlessInteger(8) || ipType.isF8E5M2() ||
+            ipType.isF8E4M3FN()) {
 
           shuffle1 = vector::ShuffleOp::create(
               rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
@@ -419,7 +420,7 @@ createTiledDp(OpBuilder &rewriter, Location loc,
     auto accTileType = amx::TileType::get({16, 16}, opType);
 
     Value dp;
-    if (ipType.isBF16())
+    if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
       dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
                                    tilesRhs, accIterArgs[i]);
 
@@ -725,15 +726,20 @@ struct VectorContractToAMXDotProduct
 
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isBF16() &&
-        !lhsTy.getElementType().isSignlessInteger(8))
+        !lhsTy.getElementType().isSignlessInteger(8) &&
+        !lhsTy.getElementType().isF8E4M3FN() &&
+        !lhsTy.getElementType().isF8E5M2())
       return rewriter.notifyMatchFailure(
-          contractOp, "Only BF16/Int8 lowering is supported.");
+          contractOp, "Only BF16/Int8/F8 lowering is supported.");
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
       return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
 
-    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+    if (((lhsTy.getElementType().isBF16() ||
+          lhsTy.getElementType().isF8E4M3FN() ||
+          lhsTy.getElementType().isF8E5M2()) &&
+         !accTy.getElementType().isF32()) ||
         (lhsTy.getElementType().isSignlessInteger(8) &&
          !accTy.getElementType().isSignlessInteger(32)))
       return rewriter.notifyMatchFailure(contractOp,
@@ -760,6 +766,12 @@ struct VectorContractToAMXDotProduct
       opType = rewriter.getIntegerType(32);
     }
 
+    if (lhsTy.getElementType().isF8E4M3FN())
+      ipType = rewriter.getF8E4M3FNType();
+
+    if (lhsTy.getElementType().isF8E5M2())
+      ipType = rewriter.getF8E5M2Type();
+
     if (accReadOp->getBlock() == contractOp->getBlock() &&
         resultWriteOp->getBlock() != contractOp->getBlock())
       return rewriter.notifyMatchFailure(
@@ -932,7 +944,7 @@ struct VectorContractToAMXDotProduct
 
       // Tiled dot-product.
       Value dp;
-      if (ipType.isBF16())
+      if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
         dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
                                      loadRhs, loadAcc);
 
@@ -1359,7 +1371,7 @@ struct VectorContractToAMXDotProduct
           Value addOp;
           Value addOp2;
 
-          if (ipType.isBF16()) {
+          if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN()) {
             addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
 
             addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 20d269fd6ff88..1ebb7010050ab 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -216,6 +216,60 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x16x16x4xf8E5M2>
+!vecB = vector<1x16x16x4xf8E5M2>
+!vecC = vector<16x16xf32>
+!memrefA = memref<1x32x16x4xf8E5M2>
+!memrefB = memref<1x16x32x4xf8E5M2>
+!memrefC = memref<32x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_f8E5M2(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : f8E5M2
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+
+  %3 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} : !memrefC, !vecC
+
+  %4 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %3 : !vecA, !vecB into !vecC
+
+  vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @brgemm_f8E5M2
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xf8E5M2>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x64xf8E5M2>
+// CHECK: x86.amx.tile_load {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK: x86.amx.tile_mulf
+// CHECK: x86.amx.tile_store {{.*}} !x86.amx.tile<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<16x64xi8>
 !vecB = vector<64x16xi8>
 !vecC = vector<16x16xi32>
@@ -523,6 +577,88 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecAB = vector<16x16x4xf8E4M3FN>
+!vecC = vector<16x16xf32>
+!memrefA = memref<16x16x4xf8E4M3FN, strided<[256, 4, 1], offset: ?>>
+!memrefB = memref<16x32x4xf8E4M3FN, strided<[512, 4, 1], offset: ?>>
+!memrefC = memref<16x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @matmul_f8E4M3FN_loop(%arg0: memref<64x64x4xf8E4M3FN>, %arg1: memref<64x128x4xf8E4M3FN>, %arg2: memref<64x128xf32>) {
+  %0 = ub.poison : f32
+  %1 = ub.poison : f8E4M3FN
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c16 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [16, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+
+        %subview_0 = memref.subview %arg0[%arg3, %arg5, 0] [16, 16, 4] [1, 1, 1] :
+                memref<64x64x4xf8E4M3FN> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4, 0] [16, 32, 4] [1, 1, 1] :
+                memref<64x128x4xf8E4M3FN> to !memrefB
+        %5 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefA, !vecAB
+        %6 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+        %7 = vector.transfer_read %subview_1[%c0, %c16, %c0], %1 {in_bounds = [true, true, true]} :
+                !memrefB, !vecAB
+
+        %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %6, %arg6 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+        %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %5, %7, %arg7 {unroll_shape = array<i64: 4, 16, 16, 16>} : !vecAB, !vecAB into !vecC
+
+        scf.yield %8, %9 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: @matmul_f8E4M3FN_loop
+// CHECK-COUNT-2: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK-COUNT-3: x86.amx.tile_load
+// CHECK-COUNT-2: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecAB = vector<1x16x16x4xi8>
 !vecC = vector<1x16x16xi32>
 !memrefA = memref<1x16x16x4xi8, strided<[16384, 256, 4, 1], offset: ?>>
@@ -712,6 +848,89 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<16x64xf8E5M2>
+!vecB = vector<64x16xf8E5M2>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x64xf8E5M2, strided<[256, 1], offset: ?>>
+!memrefB = memref<64x32xf8E5M2, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : f32
+  %1 = ub.poison : f8E5M2
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c32 = arith.constant 32 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %6:4 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg5] [32, 64] [1, 1] : memref<64x256xf8E5M2> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg5, %arg4] [64, 32] [1, 1] : memref<256x128xf8E5M2> to !memrefB
+        %7 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+        %8 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} : !memrefA, !vecA
+        %9 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+        %10 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} : !memrefB, !vecB
+        %11 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %9, %arg6 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %7, %10, %arg7 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %9, %arg8 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %10, %arg9 {unroll_shape = array<i64: 16, 16, 64>} : !vecA, !vecB into !vecC
+        scf.yield %11, %12, %13, %14 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %6#3, %subview[%c16, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c16, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xf32>
+  memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+  return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xf8E5M2>, vector<64xf8E5M2>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xf8E5M2>, vector<64xf8E5M2>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecAB = vector<1x16x16x2xbf16>
 !vecC = vector<16x16xf32>
 !memrefA = memref<1x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>

>From 4d6cc4949aad936848ccb094184775ce964ec210 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 12 May 2026 22:39:12 -0700
Subject: [PATCH 04/10] relaxed a condition on ShapeCast which is not needed.

---
 .../VectorContractToAMXDotProduct.cpp         | 34 +++++++++----------
 mlir/lib/Dialect/X86/Utils/X86Utils.cpp       |  2 +-
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 595f585ad2d61..34f6298c00b34 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,23 +27,20 @@ using namespace mlir::x86;
 
 namespace {
 
-static Value traceToVectorWriteLikeUserOperationForAMX(Value v) {
+static Value contractionUsersAfterYield(Value v) {
   if (v.getNumUses() > 1)
     return nullptr;
 
   for (OpOperand &use : v.getUses()) {
     Operation *user = use.getOwner();
 
-    if (!isa<scf::YieldOp>(user)) {
+    if (!isa<scf::YieldOp>(user))
       return v;
-    }
 
-    // --- SCF YIELD ---
     if (auto yield = dyn_cast<scf::YieldOp>(user)) {
       Operation *parent = yield->getParentOp();
       unsigned idx = use.getOperandNumber();
-      if (auto res =
-              traceToVectorWriteLikeUserOperationForAMX(parent->getResult(idx)))
+      if (auto res = contractionUsersAfterYield(parent->getResult(idx)))
         return res;
       continue;
     }
@@ -494,7 +491,9 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
   Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
   Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
 
-  int64_t offset = step.getDefiningOp<arith::ConstantIndexOp>().value();
+  int64_t offset = 16 * blockingFactor;
+  if (auto cst = step.getDefiningOp<arith::ConstantIndexOp>())
+    offset = cst.value();
 
   auto newLoop = scf::ForOp::create(
       rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
@@ -970,8 +969,7 @@ struct VectorContractToAMXDotProduct
           rewriter, loc, flatTy, resultBuffer, ValueRange{c0, c0}, padding, map,
           inBounds);
 
-      Value resultOp =
-          traceToVectorWriteLikeUserOperationForAMX(contractOp.getResult());
+      Value resultOp = contractionUsersAfterYield(contractOp.getResult());
       if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
         vecRow =
             mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
@@ -1236,15 +1234,16 @@ struct VectorContractToAMXDotProduct
             (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
 
         rewriter.setInsertionPoint(innerLoop);
+
         auto c0 =
             arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
-
-        int64_t stepVal =
-            innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>().value();
+        int64_t offset = 16 * blockingFactor;
+        if (auto cst =
+                innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
+          offset = cst.value();
 
         auto spillLoopBound = arith::ConstantIndexOp::create(
-            rewriter, innerLoop.getLoc(), stepVal);
-
+            rewriter, innerLoop.getLoc(), offset);
         Value spillInnerLoop =
             arith::SubIOp::create(rewriter, innerLoop.getLoc(),
                                   innerLoop.getUpperBound(), spillLoopBound);
@@ -1417,12 +1416,11 @@ struct VectorContractToAMXDotProduct
       vector::ContractionOp contOp = ops[i];
       Value vecRoc = writeResults[i];
 
-      Value resultWriteOp =
-          traceToVectorWriteLikeUserOperationForAMX(contOp.getResult());
-      if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType())) {
+      Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
+      if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
         vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
                                                    writeResults[i]);
-      }
+
       resultWriteOp.replaceAllUsesWith(vecRoc);
     }
 
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index aea6bf6adcd4a..2cd6012b9ec03 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -203,7 +203,7 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
     if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
       return user;
 
-    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+    if (isa<vector::ShuffleOp>(user))
       return nullptr;
 
     // --- SCF YIELD ---

>From 2574d27a1e094aa9715176fad8f2c524113b9970 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 13 May 2026 20:42:22 -0700
Subject: [PATCH 05/10] extra validation on contraction lhs and rhs type.

---
 .../Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp  | 4 ++++
 mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir    | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 34f6298c00b34..e31c0f6c7586e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -731,6 +731,10 @@ struct VectorContractToAMXDotProduct
       return rewriter.notifyMatchFailure(
           contractOp, "Only BF16/Int8/F8 lowering is supported.");
 
+    if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Contraction should have same lhs and rhs type.");
+
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
       return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index 1ebb7010050ab..d8950bd8baaf7 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -858,7 +858,7 @@ module attributes {transform.with_named_sequence} {
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
+func.func @online_packing_f8E5M2_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1: memref<256x128xf8E5M2>, %arg2: memref<64x128xf32>) -> memref<64x128xf32> {
   %c16 = arith.constant 16 : index
   %0 = ub.poison : f32
   %1 = ub.poison : f8E5M2
@@ -906,7 +906,7 @@ func.func @online_packing_int8_matmul_loop(%arg0: memref<64x256xf8E5M2>, %arg1:
   return %alloc : memref<64x128xf32>
 }
 
-// CHECK-LABEL: @online_packing_int8_matmul_loop
+// CHECK-LABEL: @online_packing_f8E5M2_matmul_loop
 // CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
 // CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
 // CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xf8E5M2>, vector<64xf8E5M2>

>From 433ebe854bdf4533903b76a71c49049ea5a676bf Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 13 May 2026 21:01:03 -0700
Subject: [PATCH 06/10] clean-up.

---
 .../X86/Transforms/VectorContractToAMXDotProduct.cpp  | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index e31c0f6c7586e..a425d1534d836 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -581,9 +581,8 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
             if (outerPos >= 0) {
               unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
 
-              if (operandIdx < rhsOp->getNumOperands()) {
+              if (operandIdx < rhsOp->getNumOperands())
                 rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
-              }
             }
           }
 
@@ -592,19 +591,18 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
           if (innerPos >= 0) {
             unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
 
-            if (operandIdx < rhsOp->getNumOperands()) {
+            if (operandIdx < rhsOp->getNumOperands())
               rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
-            }
           }
 
           auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
-
           matB = rhsClone->getResult(0);
 
         } else {
           // memref.get_global / constants
           matB = rhsOp->getResult(0);
         }
+
         if (!isVnni) {
           if (outerLoop) {
             if (!pack) {
@@ -1332,7 +1330,7 @@ struct VectorContractToAMXDotProduct
     auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
 
     // Create a loop that iterates over the MxN memerf, retrives two rows +
-    // shuffle them, add up the C element values and stores them back.
+    // shuffle them, add up the C element values and stores them to temp buffer.
     scf::ForOp::create(
         rewriter, loc, c0, mBound, one, ValueRange{},
         [&](OpBuilder &nestedBuilder, Location loc, Value iv,
@@ -1416,6 +1414,7 @@ struct VectorContractToAMXDotProduct
       }
     }
 
+    // Replace use of vector.contract with dot-products.
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
       Value vecRoc = writeResults[i];

>From 22473a87e0fcaac055971b48516f4b972d4085a7 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 15 May 2026 09:30:25 -0700
Subject: [PATCH 07/10] minor fix in int8 shuffling

---
 .../lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index a425d1534d836..6c02d5229119e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -876,6 +876,7 @@ struct VectorContractToAMXDotProduct
                   arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
 
               indicesRhs[indicesRhs.size() - 2] = iv;
+              indicesRhs[indicesRhs.size() - 1] = c0;
               ValueRange range1(indicesRhs);
               auto vec1 = vector::LoadOp::create(
                   rewriter, loc,

>From a274536b6dc589b9abfa57d5bf78e11f8900001c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 20 May 2026 23:25:12 -0700
Subject: [PATCH 08/10] clean-up, comments and use of rewriter

---
 .../VectorContractToAMXDotProduct.cpp         | 54 +++++++++----------
 mlir/lib/Dialect/X86/Utils/X86Utils.cpp       |  4 +-
 2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 6c02d5229119e..4b1f9adfde76d 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -27,26 +27,23 @@ using namespace mlir::x86;
 
 namespace {
 
+// Follows a single-use value through scf.yield ops to the parent op's result
+// and returns the first value whose only user is not a yield; null otherwise.
 static Value contractionUsersAfterYield(Value v) {
-  if (v.getNumUses() > 1)
+  if (v.getNumUses() != 1)
     return nullptr;
 
-  for (OpOperand &use : v.getUses()) {
-    Operation *user = use.getOwner();
+  OpOperand &use = *v.use_begin();
+  Operation *user = use.getOwner();
 
-    if (!isa<scf::YieldOp>(user))
-      return v;
+  if (!isa<scf::YieldOp>(user))
+    return v;
 
-    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
-      Operation *parent = yield->getParentOp();
-      unsigned idx = use.getOperandNumber();
-      if (auto res = contractionUsersAfterYield(parent->getResult(idx)))
-        return res;
-      continue;
-    }
-  }
+  auto yield = cast<scf::YieldOp>(user);
+  Operation *parent = yield->getParentOp();
+  unsigned idx = use.getOperandNumber();
 
-  return nullptr;
+  return contractionUsersAfterYield(parent->getResult(idx));
 }
 
 // Function to collapse the last two dimension (vnni and k) to help the
@@ -138,6 +135,9 @@ static LogicalResult validateContractOps(OpBuilder &rewriter,
 
     if (buffRhs != srcBuffRhs)
       return failure();
+
+    if (!contractionUsersAfterYield(contractOp.getResult()))
+      return failure();
   }
 
   VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
@@ -572,7 +572,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
         Value matB;
         Operation *rhsOp = vectorOpRhs;
 
-        // Clone only if the op has operands.
+        // Clone ops that take operands (e.g. memref.subview).
         if (rhsOp->getNumOperands() > 0) {
 
           if (outerLoop) {
@@ -599,7 +599,7 @@ createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
           matB = rhsClone->getResult(0);
 
         } else {
-          // memref.get_global / constants
+          // The mat B is of kind 'memref.get_global @__constant'
           matB = rhsOp->getResult(0);
         }
 
@@ -803,7 +803,7 @@ struct VectorContractToAMXDotProduct
         return rewriter.notifyMatchFailure(
             contractOp, "The contract operation doesn't satisfy the operands "
                         "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
-                        "The rest dims should be 1.");
+                        "The rest dims should be 1. Op should have one user.");
 
       Location loc = contractOp.getLoc();
 
@@ -978,7 +978,7 @@ struct VectorContractToAMXDotProduct
             mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
       }
 
-      resultOp.replaceAllUsesWith(vecRow);
+      rewriter.replaceAllUsesWith(resultOp, vecRow);
       return success();
     }
 
@@ -1044,9 +1044,10 @@ struct VectorContractToAMXDotProduct
 
         if (failed(validate))
           return rewriter.notifyMatchFailure(
-              contractOp, "The associated contract operations doesn't satisfy "
-                          "the re-write conditions either the dimensions are "
-                          "wrong or MemRef source are different.");
+              contractOp,
+              "The associated contract operations doesn't satisfy "
+              "the re-write conditions either the dimensions are "
+              "wrong or MemRef source are different or many users.");
 
         ops.push_back(contract);
       }
@@ -1072,7 +1073,6 @@ struct VectorContractToAMXDotProduct
     scf::ForOp newLoop;
     // Case 2a: Reduction loop depth is 2.
     if (loopLists.size() == 2) {
-
       outerLoop = loopLists[1];
       innerLoop = loopLists[0];
 
@@ -1328,12 +1328,12 @@ struct VectorContractToAMXDotProduct
     auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
     auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
     auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    auto mBound = arith::ConstantIndexOp::create(rewriter, loc, N);
+    auto nBound = arith::ConstantIndexOp::create(rewriter, loc, N);
 
     // Create a loop that iterates over the MxN memerf, retrives two rows +
     // shuffle them, add up the C element values and stores them to temp buffer.
     scf::ForOp::create(
-        rewriter, loc, c0, mBound, one, ValueRange{},
+        rewriter, loc, c0, nBound, one, ValueRange{},
         [&](OpBuilder &nestedBuilder, Location loc, Value iv,
             ValueRange iterArgs) {
           auto row =
@@ -1418,14 +1418,14 @@ struct VectorContractToAMXDotProduct
     // Replace use of vector.contract with dot-products.
     for (size_t i = 0; i < ops.size(); i++) {
       vector::ContractionOp contOp = ops[i];
-      Value vecRoc = writeResults[i];
+      Value vecRow = writeResults[i];
 
       Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
       if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
-        vecRoc = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
+        vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
                                                    writeResults[i]);
 
-      resultWriteOp.replaceAllUsesWith(vecRoc);
+      rewriter.replaceAllUsesWith(resultWriteOp, vecRow);
     }
 
     return success();
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index 2cd6012b9ec03..a04a026f35ae6 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -187,8 +187,8 @@ Operation *traceToVectorReadLikeParentOperation(Value v) {
 // This function recursively traces a value through its uses to find
 // a downstream vector write-like operation (`vector.transfer_write`
 // or `vector.store`). It transparently follows values across `scf.for`
-// and `scf.yield` boundaries while stopping if layout-altering ops such
-// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// and `scf.yield` boundaries while stopping if layout-altering ops
+// like `shuffle` are encountered. The traversal returns
 // the  matching write-like user. Returns `nullptr` if none is found or
 // the value has multiple users.
 Operation *traceToVectorWriteLikeUserOperation(Value v) {

>From b402d03db7f7b80b694d17dd5f2d183fda82682b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 24 May 2026 01:02:39 -0700
Subject: [PATCH 09/10] support for the reduction loops to accept dynamic lower
 bounds

---
 .../VectorContractToAMXDotProduct.cpp         | 34 ++++++++++++++++---
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 4b1f9adfde76d..ee6a2bf0ef6ca 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -1138,15 +1138,29 @@ struct VectorContractToAMXDotProduct
         rhsMapping.map(
             vectorOpRhs->getOperand(
                 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
-            c0);
+            outerLoop.getLowerBound());
         rhsMapping.map(
             vectorOpRhs->getOperand(
                 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
-            c0);
+            innerLoop.getLowerBound());
         auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
 
+        Value quotient_batch = arith::DivUIOp::create(
+            rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
+            outerLoop.getStep());
+        Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
+                                                  innerLoop.getLowerBound(),
+                                                  innerLoop.getStep());
+
+        Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
+                                                   quotient_batch, quotient_k);
+        Value c2 =
+            arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 2);
+        Value rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
+                                           quotient_add, c2);
+
         performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
-                       ipType, blockingFactor, packedBuffer, c0);
+                       ipType, blockingFactor, packedBuffer, rem);
 
         // First Set of Loops
         auto newLoopNonSpill = scf::ForOp::create(
@@ -1261,10 +1275,20 @@ struct VectorContractToAMXDotProduct
         rhsMapping.map(
             vectorOpRhs->getOperand(
                 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
-            c0);
+            innerLoop.getLowerBound());
         auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
+
+        Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
+                                                  innerLoop.getLowerBound(),
+                                                  innerLoop.getStep());
+        Value c2 =
+            arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 2);
+        Value rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
+                                           quotient_k, c2);
+
         performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
-                       ipType, blockingFactor, packedBuffer, c0);
+                       ipType, blockingFactor, packedBuffer, rem);
+
         auto newLoopNonSpill = createLoops(
             rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
             spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,

>From ef01f0b6fa3f4f96ef227c867b77a00898d76c6b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 25 May 2026 04:51:30 -0700
Subject: [PATCH 10/10] check on loop step + new test-cases

---
 .../VectorContractToAMXDotProduct.cpp         |  52 ++++-
 .../X86/AMX/vector-contract-to-tiled-dp.mlir  | 202 ++++++++++++++++++
 2 files changed, 247 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index ee6a2bf0ef6ca..82840a4f86d79 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -107,6 +107,20 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
   return std::make_pair(srcBuff, indices);
 }
 
+// Succeeds only if `step` is a constant index equal to `value` or to 1.
+static LogicalResult validateLoopStep(OpBuilder &rewriter, Value step,
+                                      int64_t value) {
+
+  auto cst = step.getDefiningOp<arith::ConstantIndexOp>();
+  if (!cst)
+    return failure();
+
+  if (cst.value() != value && cst.value() != 1)
+    return failure();
+
+  return success();
+}
+
 // Function to validate the vector.contract operation.
 static LogicalResult validateContractOps(OpBuilder &rewriter,
                                          vector::ContractionOp contractOp,
@@ -135,11 +149,11 @@ static LogicalResult validateContractOps(OpBuilder &rewriter,
 
     if (buffRhs != srcBuffRhs)
       return failure();
-
-    if (!contractionUsersAfterYield(contractOp.getResult()))
-      return failure();
   }
 
+  if (!contractionUsersAfterYield(contractOp.getResult()))
+    return failure();
+
   VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
   if (!accTy)
     return failure();
@@ -973,10 +987,8 @@ struct VectorContractToAMXDotProduct
           inBounds);
 
       Value resultOp = contractionUsersAfterYield(contractOp.getResult());
-      if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType())) {
-        vecRow =
-            mlir::vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
-      }
+      if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
+        vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
 
       rewriter.replaceAllUsesWith(resultOp, vecRow);
       return success();
@@ -1076,6 +1088,21 @@ struct VectorContractToAMXDotProduct
       outerLoop = loopLists[1];
       innerLoop = loopLists[0];
 
+      LogicalResult validateOuterLoopStep =
+          validateLoopStep(rewriter, outerLoop.getStep(), 1);
+      if (failed(validateOuterLoopStep))
+        return rewriter.notifyMatchFailure(contractOp, "Invalid loop step.");
+
+      int64_t stepValue = 16;
+      if (!isVnni)
+        stepValue = stepValue * blockingFactor;
+      LogicalResult validateInnerLoopStep =
+          validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
+      if (failed(validateInnerLoopStep))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Invalid loop step. The step should be 32 for BF16 and "
+                        "64 for Int8/F8.");
+
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
 
@@ -1221,6 +1248,17 @@ struct VectorContractToAMXDotProduct
     if (loopLists.size() == 1) {
 
       innerLoop = loopLists[0];
+      int64_t stepValue = 16;
+      if (!isVnni)
+        stepValue = stepValue * blockingFactor;
+
+      LogicalResult validateInnerLoopStep =
+          validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
+      if (failed(validateInnerLoopStep))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Invalid loop step. The step should be 32 for BF16 and "
+                        "64 for Int8/F8 or 1 if it is batch loop.");
+
       SmallVector<Value> loopItrArgs = createTileZeros(
           rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
 
diff --git a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
index d8950bd8baaf7..fb2314dfdf506 100644
--- a/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
+++ b/mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir
@@ -846,6 +846,111 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+!vecA = vector<16x32xbf16>
+!vecB = vector<32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x32xbf16, strided<[96, 1], offset: ?>>
+!memrefB = memref<32x32xbf16, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+
+func.func @online_packing_bf16_loop_lb_non_zero(%arg0: memref<64x96xbf16>, %arg1: memref<96x128xbf16>, %arg2: memref<64x128xf32>, %klb: index, %kub: index) -> memref<64x128xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c96 = arith.constant 96 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %7:4 = scf.for %arg10 = %klb to %kub step %c32 iter_args(%arg11 = %2, %arg12 = %3, %arg13 = %4, %arg14 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg10] [32, 32] [1, 1] :
+              memref<64x96xbf16> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg10, %arg4] [32, 32] [1, 1] :
+              memref<96x128xbf16> to !memrefB
+        %8 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} :
+              !memrefA, !vecA
+        %9 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} :
+              !memrefA, !vecA
+        %10 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} :
+              !memrefB, !vecB
+        %11 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} :
+              !memrefB, !vecB
+
+        %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+        scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %7#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xf32>
+  memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+  return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @online_packing_bf16_loop_lb_non_zero
+// CHECK-COUNT-4: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK: scf.for {{.*}} -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.amx.tile_load
+// CHECK: x86.amx.tile_mulf
+// CHECK: scf.yield {{.*}} : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-NOT: scf.for {{.*}} vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>, vector<16x16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 !vecA = vector<16x64xf8E5M2>
@@ -1458,6 +1563,103 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<16x32xbf16>
+!vecB = vector<32x16xbf16>
+!vecC = vector<16x16xf32>
+!memrefA = memref<32x32xbf16, strided<[96, 1], offset: ?>>
+!memrefB = memref<32x32xbf16, strided<[128, 1], offset: ?>>
+!memrefC = memref<32x32xf32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d1, d2, d3) -> (d3, d2)>
+#map2 = affine_map<(d1, d2, d3) -> (d1, d2)>
+
+func.func @negative_online_packing_bf16_dynamic_loop_step(%arg0: memref<64x96xbf16>, %arg1: memref<96x128xbf16>, %arg2: memref<64x128xf32>, %kStep: index) -> memref<64x128xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c96 = arith.constant 96 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c32 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+
+      %subview = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] :
+                memref<64x128xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c16, %c0], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c16, %c16], %0 {in_bounds = [true, true]} :
+                !memrefC, !vecC
+
+      %7:4 = scf.for %arg10 = %c0 to %c32 step %kStep iter_args(%arg11 = %2, %arg12 = %3, %arg13 = %4, %arg14 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %subview_0 = memref.subview %arg0[%arg3, %arg10] [32, 32] [1, 1] :
+              memref<64x96xbf16> to !memrefA
+        %subview_1 = memref.subview %arg1[%arg10, %arg4] [32, 32] [1, 1] :
+              memref<96x128xbf16> to !memrefB
+        %8 = vector.transfer_read %subview_0[%c0, %c0], %1 {in_bounds = [true, true]} :
+              !memrefA, !vecA
+        %9 = vector.transfer_read %subview_0[%c16, %c0], %1 {in_bounds = [true, true]} :
+              !memrefA, !vecA
+        %10 = vector.transfer_read %subview_1[%c0, %c0], %1 {in_bounds = [true, true]} :
+              !memrefB, !vecB
+        %11 = vector.transfer_read %subview_1[%c0, %c16], %1 {in_bounds = [true, true]} :
+              !memrefB, !vecB
+
+        %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %8, %10, %arg11 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %8, %11, %arg12 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %9, %10, %arg13 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+        %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+              ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+              %9, %11, %arg14 {unroll_shape = array<i64: 1, 16, 16, 32>} : !vecA, !vecB into !vecC
+
+        scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %7#3, %subview[%c16, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#2, %subview[%c16, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#1, %subview[%c0, %c16] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+      vector.transfer_write %7#0, %subview[%c0, %c0] {in_bounds = [true, true]} :
+                !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xf32>
+  memref.copy %arg2, %alloc : memref<64x128xf32> to memref<64x128xf32>
+  return %alloc : memref<64x128xf32>
+}
+
+// CHECK-LABEL: @negative_online_packing_bf16_dynamic_loop_step
+// CHECK-NOT: x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+// CHECK-NOT: x86.amx.tile_load
+// CHECK-NOT: x86.amx.tile_mulf
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_amx_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecAB = vector<1x1x16x16x4xi8>
 !vecC = vector<16x16xi32>
 !memrefA = memref<1x1x16x16x4xi8, strided<[262144, 16384, 256, 4, 1], offset: ?>>



More information about the Mlir-commits mailing list