[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)

Arun Thangamani llvmlistbot at llvm.org
Mon Oct 27 03:34:58 PDT 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/163382

>From ce42a09598082c2ca7562bbf80312f5b9c3d60f8 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 14 Oct 2025 05:29:32 -0700
Subject: [PATCH 1/6] lower vector.contract to sequence of FMAs

---
 .../mlir/Dialect/X86Vector/Transforms.h       |   9 +
 .../X86Vector/Transforms/CMakeLists.txt       |   1 +
 .../X86Vector/Transforms/NanoKernels.cpp      | 483 ++++++++++++++++++
 3 files changed, 493 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..cde890038f20a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,11 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Nano-kernels
+LogicalResult nanoKernels(RewriterBase &rewriter,
+                           vector::ContractionOp contractOp, int64_t vectorSize);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  NanoKernels.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..bc03b567e06ff
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,483 @@
+//===- NanoKernels.cpp - Lower matmul to nanokernels ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that lower register-tiled matmul loop nests
+// (vector.contract) into target-specific FMA nanokernels. FP32 is supported;
+// BF16 support is TODO.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Collects the `dimCount` closest enclosing scf.for ops of `contractOp`,
+/// innermost first, and fails if there are fewer than `dimCount` of them.
+///
+/// The expected nest is a register-tiled (batch-reduce) matmul
+/// (M -> N -> batch-reduce -> K), so callers read the result as
+/// [K, batch-reduce, N, M] or, without the batch-reduce loop, [K, N, M].
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
+  SmallVector<scf::ForOp> enclosingLoops;
+  Operation *cursor = contractOp;
+  for (int64_t depth = 0; depth < dimCount; ++depth) {
+    auto forOp = cursor->getParentOfType<scf::ForOp>();
+    if (!forOp)
+      return failure();
+    enclosingLoops.push_back(forOp);
+    cursor = forOp;
+  }
+  return enclosingLoops;
+}
+
+/// Verifies that the loops returned by `getNestedLoop` index the LHS, RHS and
+/// accumulator subviews the way the nanokernel rewrite expects.
+///
+/// `loops` is innermost first: [K, batch-reduce, N, M] for `dimCount == 4`
+/// and [K, N, M] for `dimCount == 3`. `subviews` is [lhs, rhs, acc]. Offsets
+/// come from `getOffsets()`, i.e. only the *dynamic* offset operands, in
+/// operand order -- the same operand positions `createLoop` remaps when it
+/// clones the subviews. Any other `dimCount` is accepted without checks.
+///
+/// Returns failure (instead of indexing out of bounds) when `loops` or
+/// `subviews` is too short, or when a subview has fewer dynamic offsets than
+/// the pattern needs (e.g. because some of its offsets are static).
+static LogicalResult checkNestedLoop(ArrayRef<scf::ForOp> loops,
+                                     ArrayRef<memref::SubViewOp> subviews,
+                                     int64_t dimCount) {
+  if (dimCount != 3 && dimCount != 4)
+    return success();
+
+  if (subviews.size() < 3 || loops.size() < static_cast<size_t>(dimCount))
+    return failure();
+
+  // Op accessors are non-const, so work on local copies of the handles.
+  memref::SubViewOp lhsSubview = subviews[0];
+  memref::SubViewOp rhsSubview = subviews[1];
+  memref::SubViewOp accSubview = subviews[2];
+  ValueRange lhsOffsets = lhsSubview.getOffsets();
+  ValueRange rhsOffsets = rhsSubview.getOffsets();
+  ValueRange accOffsets = accSubview.getOffsets();
+
+  // True iff `offsets` has a dynamic entry at `pos` and that entry is `iv`.
+  auto offsetIs = [](ValueRange offsets, size_t pos, Value iv) {
+    return pos < offsets.size() && offsets[pos] == iv;
+  };
+  // Induction variable of the `idx`-th loop (innermost first).
+  auto inductionVar = [&](size_t idx) -> Value {
+    scf::ForOp loop = loops[idx];
+    return loop.getInductionVar();
+  };
+
+  if (dimCount == 4) {
+    Value ivK = inductionVar(0);
+    Value ivReduction = inductionVar(1);
+    Value ivN = inductionVar(2);
+    Value ivM = inductionVar(3);
+    // lhs offsets: [batch-reduce, M, K]; rhs: [batch-reduce, K, N];
+    // acc: [M, N].
+    return success(offsetIs(lhsOffsets, 2, ivK) &&
+                   offsetIs(rhsOffsets, 1, ivK) &&
+                   offsetIs(lhsOffsets, 0, ivReduction) &&
+                   offsetIs(rhsOffsets, 0, ivReduction) &&
+                   offsetIs(accOffsets, 1, ivN) &&
+                   offsetIs(rhsOffsets, 2, ivN) &&
+                   offsetIs(lhsOffsets, 1, ivM) &&
+                   offsetIs(accOffsets, 0, ivM));
+  }
+
+  Value ivK = inductionVar(0);
+  Value ivN = inductionVar(1);
+  Value ivM = inductionVar(2);
+  // lhs offsets: [M, K]; rhs: [K, N]; acc: [M, N].
+  return success(offsetIs(lhsOffsets, 1, ivK) && offsetIs(rhsOffsets, 0, ivK) &&
+                 offsetIs(accOffsets, 1, ivN) && offsetIs(rhsOffsets, 1, ivN) &&
+                 offsetIs(lhsOffsets, 0, ivM) && offsetIs(accOffsets, 0, ivM));
+}
+
+/// Loads the M x N accumulator tile `subviewOpAcc` as a list of
+/// `vector<vectorSize x elementType>` chunks, one per
+/// acc[row][col : col + vectorSize] with col stepping by `vectorSize`.
+///
+/// The chunk order must match the order `nanoKernels` consumes and produces
+/// accumulators in, and the order `accVector` reassembles them in:
+///   * (N / vectorSize) >  M: column-chunk major (all rows of chunk 0, ...).
+///   * (N / vectorSize) <= M: row major (all chunks of row 0, ...).
+/// Assumes N is a multiple of `vectorSize`.
+static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
+                                  Type elementType, int64_t M, int64_t N,
+                                  int64_t vectorSize, Value subviewOpAcc) {
+  SmallVector<Value> loopItrArgs;
+  auto chunkType = VectorType::get(vectorSize, elementType);
+
+  // Emits exactly one index constant per dimension (no dead constants) and
+  // the load of acc[row][col : col + vectorSize].
+  auto loadChunk = [&](int64_t row, int64_t col) {
+    Value rowIdx = rewriter.create<arith::ConstantIndexOp>(loc, row);
+    Value colIdx = rewriter.create<arith::ConstantIndexOp>(loc, col);
+    loopItrArgs.push_back(rewriter.create<vector::LoadOp>(
+        loc, chunkType, subviewOpAcc, ValueRange{rowIdx, colIdx}));
+  };
+
+  if ((N / vectorSize) > M) {
+    for (int64_t col = 0; col < N; col += vectorSize)
+      for (int64_t row = 0; row < M; ++row)
+        loadChunk(row, col);
+  } else {
+    for (int64_t row = 0; row < M; ++row)
+      for (int64_t col = 0; col < N; col += vectorSize)
+        loadChunk(row, col);
+  }
+
+  return loopItrArgs;
+}
+
+/// Emits the FMA "nanokernel" for one step of the K loop and returns the
+/// updated accumulator vectors, in the same order as `acc`.
+///
+/// `matA` / `matB` are the LHS / RHS subviews for this step. Their leading
+/// batch-reduce index is present only when `dimCount == 4`. Every
+/// accumulator chunk C[m][n : n + vectorSize] is updated as
+///   C += broadcast(A[m][0 : vnni]) * B[0][n : n + vectorSize].
+/// The accumulator order must match `loadAcc` and `accVector`:
+///   * (N / vectorSize) >  M: column-chunk major. The M broadcast A values are
+///     materialized first, then each B chunk is loaded and used for all rows.
+///   * (N / vectorSize) <= M: row major. The N / vectorSize B chunks are
+///     loaded first, then each broadcast A value is used for all chunks.
+/// NOTE(review): broadcasting vector<vnni> to vector<vectorSize> presumably
+/// only verifies for vnni == 1, which is the only value passed today.
+static SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
+                                      Type elementType, int64_t vectorSize,
+                                      int64_t vnni, int64_t M, int64_t N,
+                                      ValueRange acc, Value matA, Value matB,
+                                      int64_t dimCount) {
+  // A single predicate drives every layout decision. It must be the same one
+  // loadAcc/accVector use, so that (N / vectorSize) == M is handled
+  // consistently: row-major order, B chunks loaded, A values broadcast.
+  const bool nChunkMajor = (N / vectorSize) > M;
+  const int64_t numNChunks = N / vectorSize;
+
+  auto accVecType = VectorType::get(vectorSize, elementType);
+  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+  // Index list for A/B; drops the batch-reduce index when there is none.
+  auto makeIndices = [&](Value second, Value third) {
+    SmallVector<Value> indices = {c0, second, third};
+    if (dimCount == 3)
+      indices.erase(indices.begin());
+    return indices;
+  };
+  // Loads A[m][0 : vnni] and broadcasts it to a full accumulator vector.
+  auto loadBroadcastA = [&](int64_t m) -> Value {
+    Value mIdx = rewriter.create<arith::ConstantIndexOp>(loc, m);
+    Value elem = rewriter.create<vector::LoadOp>(
+        loc, VectorType::get(vnni, elementType), matA, makeIndices(mIdx, c0));
+    return rewriter.create<vector::BroadcastOp>(loc, accVecType, elem);
+  };
+  // Loads B[0][n : n + vectorSize].
+  auto loadB = [&](int64_t n) -> Value {
+    Value nIdx = rewriter.create<arith::ConstantIndexOp>(loc, n);
+    return rewriter.create<vector::LoadOp>(loc, accVecType, matB,
+                                           makeIndices(c0, nIdx));
+  };
+
+  SmallVector<Value> updatedAcc;
+  updatedAcc.reserve(acc.size());
+  if (nChunkMajor) {
+    SmallVector<Value> aRows;
+    for (int64_t m = 0; m < M; ++m)
+      aRows.push_back(loadBroadcastA(m));
+    for (int64_t n = 0, k = 0; n < N; n += vectorSize) {
+      Value bChunk = loadB(n);
+      for (int64_t m = 0; m < M; ++m)
+        updatedAcc.push_back(
+            rewriter.create<vector::FMAOp>(loc, aRows[m], bChunk, acc[k++]));
+    }
+  } else {
+    SmallVector<Value> bChunks;
+    for (int64_t n = 0; n < N; n += vectorSize)
+      bChunks.push_back(loadB(n));
+    for (int64_t m = 0, k = 0; m < M; ++m) {
+      Value aRow = loadBroadcastA(m);
+      for (int64_t c = 0; c < numNChunks; ++c)
+        updatedAcc.push_back(
+            rewriter.create<vector::FMAOp>(loc, bChunks[c], aRow, acc[k++]));
+    }
+  }
+
+  return updatedAcc;
+}
+
+/// Reassembles the per-chunk accumulator vectors `FMAs` into `accVec`, a flat
+/// `vecType` vector of M * N elements laid out row major. Each chunk is
+/// inserted with vector.insert_strided_slice (stride 1) at its flat offset.
+///
+/// The chunk order in `FMAs` follows the loadAcc/nanoKernels convention:
+/// column-chunk major when (N / vecSize) > M, row major otherwise.
+/// Returns the final vector.
+static Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
+                       ValueRange FMAs, Value accVec, int64_t vecSize,
+                       int64_t M, int64_t N) {
+  auto strides = rewriter.getI64ArrayAttr({1});
+  // Inserts `chunk` into the running vector at flat element `offset`.
+  auto insertAt = [&](Value chunk, int64_t offset) {
+    accVec = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, vecType, chunk, accVec, rewriter.getI64ArrayAttr({offset}),
+        strides);
+  };
+
+  const int64_t numNChunks = N / vecSize;
+  int64_t k = 0;
+  if (numNChunks > M) {
+    for (int64_t j = 0; j < numNChunks; ++j)
+      for (int64_t i = 0; i < M; ++i)
+        insertAt(FMAs[k++], j * vecSize + i * N);
+  } else {
+    for (int64_t offset = 0; offset < M * N; offset += vecSize)
+      insertAt(FMAs[k++], offset);
+  }
+  return accVec;
+}
+
+/// Builds the new K loop for the batch-reduce (dimCount == 4) case.
+///
+/// The loop copies `kForOp`'s bounds and step and carries
+/// `iterArgsNewReductionForOp` (the accumulator chunks) as iter_args. In its
+/// body, the subviews feeding `vectorReadOpLhs` / `vectorReadOpRhs` are cloned
+/// with offsets remapped to the new induction variables:
+///   * LHS (dynamic offsets [batch-reduce, M, K], operands 1..3): operand 1 ->
+///     `ivNewReductionForOp`, operand 3 -> new K iv.
+///   * RHS (dynamic offsets [batch-reduce, K, N], operands 1..3): operand 1 ->
+///     `ivNewReductionForOp`, operand 2 -> new K iv.
+/// This operand layout is the one `checkNestedLoop` verifies. The FMA
+/// nanokernel is then emitted on the clones and its results are yielded.
+/// `loc` is unused; ops are created at `kForOp`'s location.
+/// NOTE(review): `rewriter` is used inside the body builder and is assumed to
+/// be the builder scf::ForOp::build forwards, positioned in the new body.
+static scf::ForOp createLoop(RewriterBase &rewriter, Location loc,
+                             scf::ForOp kForOp,
+                             vector::TransferReadOp vectorReadOpLhs,
+                             vector::TransferReadOp vectorReadOpRhs,
+                             Value ivNewReductionForOp, Type elementType,
+                             int64_t vectorSize, int64_t vnni, int64_t M,
+                             int64_t N, ValueRange iterArgsNewReductionForOp,
+                             int64_t dimCount) {
+  Operation *lhsSubview = vectorReadOpLhs.getBase().getDefiningOp();
+  Operation *rhsSubview = vectorReadOpRhs.getBase().getDefiningOp();
+
+  auto newKForOp = rewriter.create<scf::ForOp>(
+      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+      kForOp.getStep(), iterArgsNewReductionForOp,
+      [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+          Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+        IRMapping mapping;
+        mapping.map(lhsSubview->getOperand(1), ivNewReductionForOp);
+        mapping.map(lhsSubview->getOperand(3), ivNewKForOp);
+        Operation *lhsClone = rewriterNewKForOp.clone(*lhsSubview, mapping);
+
+        IRMapping rhsMapping;
+        rhsMapping.map(rhsSubview->getOperand(1), ivNewReductionForOp);
+        rhsMapping.map(rhsSubview->getOperand(2), ivNewKForOp);
+        Operation *rhsClone = rewriterNewKForOp.clone(*rhsSubview, rhsMapping);
+
+        auto updatedAcc =
+            nanoKernels(rewriter, kForOp.getLoc(), elementType, vectorSize,
+                        vnni, M, N, iterArgsNewKForOp, lhsClone->getResult(0),
+                        rhsClone->getResult(0), dimCount);
+
+        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, updatedAcc);
+      });
+
+  return newKForOp;
+}
+
+/// Builds the new K loop for the plain matmul (dimCount == 3) case.
+///
+/// The loop copies `kForOp`'s bounds and step and carries
+/// `iterArgsNewReductionForOp` (the accumulator chunks) as iter_args. In its
+/// body, the subviews feeding `vectorReadOpLhs` / `vectorReadOpRhs` are cloned
+/// with only their K offset remapped to the new induction variable:
+///   * LHS (dynamic offsets [M, K], operands 1..2): operand 2.
+///   * RHS (dynamic offsets [K, N], operands 1..2): operand 1.
+/// This operand layout is the one `checkNestedLoop` verifies. The FMA
+/// nanokernel (created at `loc`) is emitted on the clones and its results are
+/// yielded.
+/// NOTE(review): `rewriter` is used inside the body builder and is assumed to
+/// be the builder scf::ForOp::build forwards, positioned in the new body.
+static scf::ForOp createLoop(RewriterBase &rewriter, Location loc,
+                             scf::ForOp kForOp,
+                             vector::TransferReadOp vectorReadOpLhs,
+                             vector::TransferReadOp vectorReadOpRhs,
+                             Type elementType, int64_t vectorSize,
+                             int64_t vnni, int64_t M, int64_t N,
+                             ValueRange iterArgsNewReductionForOp,
+                             int64_t dimCount) {
+  Operation *lhsSubview = vectorReadOpLhs.getBase().getDefiningOp();
+  Operation *rhsSubview = vectorReadOpRhs.getBase().getDefiningOp();
+
+  auto newKForOp = rewriter.create<scf::ForOp>(
+      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+      kForOp.getStep(), iterArgsNewReductionForOp,
+      [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+          Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+        IRMapping mapping;
+        mapping.map(lhsSubview->getOperand(2), ivNewKForOp);
+        Operation *lhsClone = rewriterNewKForOp.clone(*lhsSubview, mapping);
+
+        IRMapping rhsMapping;
+        rhsMapping.map(rhsSubview->getOperand(1), ivNewKForOp);
+        Operation *rhsClone = rewriterNewKForOp.clone(*rhsSubview, rhsMapping);
+
+        auto updatedAcc =
+            nanoKernels(rewriter, loc, elementType, vectorSize, vnni, M, N,
+                        iterArgsNewKForOp, lhsClone->getResult(0),
+                        rhsClone->getResult(0), dimCount);
+
+        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, updatedAcc);
+      });
+
+  return newKForOp;
+}
+
+/// Lowers the register-tiled matmul loop nest around `contractOp` to FMA
+/// nanokernels.
+///
+/// Expected input (all checked before any IR is created):
+///   * an ADD `vector.contract` nested in M -> N -> K loops (RHS rank 2) or
+///     M -> N -> batch-reduce -> K loops (RHS rank 3);
+///   * LHS, RHS and accumulator produced by `vector.transfer_read`s of
+///     `memref.subview`s that are indexed by those loops' induction variables;
+///   * f32 elements, K (and batch-reduce) tile size 1, N a multiple of
+///     `vectorSize`;
+///   * the accumulator is the first iter_arg of the batch-reduce loop (rank 3)
+///     or of the K loop (rank 2).
+///
+/// The rewrite:
+///   1. loads the accumulator tile as `vectorSize`-wide chunks in front of
+///      that loop;
+///   2. rebuilds the loop(s) with the chunks as iter_args and an FMA
+///      nanokernel as the body;
+///   3. reassembles the chunks into one vector;
+///   4. replaces all uses of the original loop's first result with it.
+/// The original loops are not erased here.
+LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
+                                           vector::ContractionOp contractOp, int64_t vectorSize) {
+  auto loc = contractOp.getLoc();
+
+  if (contractOp.getKind() != vector::CombiningKind::ADD) {
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects add combining kind");
+  }
+
+  // Every N / vectorSize below needs a positive width.
+  if (vectorSize <= 0)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects a positive vector size");
+
+  auto dimCount = contractOp.getRhsType().getRank() + 1;
+
+  if ((dimCount != 3) && (dimCount != 4))
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects batch-reduce or batch matmuls");
+
+  // Get the K, (batch-reduce,) N and M loops, innermost first.
+  auto loops = getNestedLoop(contractOp, dimCount);
+  if (failed(loops))
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Invalid loop nest in contract pattern");
+
+  const SmallVector<scf::ForOp> &nestedLoops = *loops;
+  scf::ForOp kForOp = nestedLoops[0];
+  // The loop that carries the accumulator as its first iter_arg and whose
+  // first result is replaced at the end.
+  scf::ForOp accForOp = kForOp;
+  if (dimCount == 4)
+    accForOp = nestedLoops[1];
+
+  if (accForOp.getInitArgs().empty())
+    return rewriter.notifyMatchFailure(
+        contractOp, "Expects the accumulator as a loop iter_arg");
+
+  auto vectorReadOpAcc =
+      accForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+  auto vectorReadOpLhs =
+      contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+  auto vectorReadOpRhs =
+      contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+
+  if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
+    return rewriter.notifyMatchFailure(
+        contractOp, "Expects operands produced by vector.transfer_read");
+
+  auto subviewOpAcc =
+      vectorReadOpAcc.getBase().getDefiningOp<memref::SubViewOp>();
+  auto subviewOpLhs =
+      vectorReadOpLhs.getBase().getDefiningOp<memref::SubViewOp>();
+  auto subviewOpRhs =
+      vectorReadOpRhs.getBase().getDefiningOp<memref::SubViewOp>();
+
+  if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
+    return rewriter.notifyMatchFailure(
+        contractOp, "Expects transfer_read sources produced by memref.subview");
+
+  SmallVector<memref::SubViewOp> subviews = {subviewOpLhs, subviewOpRhs,
+                                             subviewOpAcc};
+
+  // The M, N, K, and batch-reduce loop ivs should match the ivs used in the
+  // subviews.
+  if (failed(checkNestedLoop(nestedLoops, subviews, dimCount)))
+    return rewriter.notifyMatchFailure(
+        contractOp, "Loops doesn't match the iv in subviews");
+
+  Type elementType =
+      cast<MemRefType>(subviewOpLhs.getType()).getElementType();
+
+  // TODO: Support for BF16 Type
+  if (!elementType.isF32())
+    return rewriter.notifyMatchFailure(
+        contractOp, "Only, FP32 type is supported");
+
+  // RHS and accumulator are loaded with the LHS element type as well.
+  if (cast<MemRefType>(subviewOpRhs.getType()).getElementType() !=
+          elementType ||
+      cast<MemRefType>(subviewOpAcc.getType()).getElementType() != elementType)
+    return rewriter.notifyMatchFailure(
+        contractOp, "Expects matching LHS, RHS and accumulator element types");
+
+  VectorType lhsType = vectorReadOpLhs.getVectorType();
+  VectorType rhsType = vectorReadOpRhs.getVectorType();
+
+  // The dimension queries below index from the back; LHS must have the same
+  // rank as RHS (2 or 3).
+  if (lhsType.getRank() != rhsType.getRank())
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects LHS and RHS of equal rank");
+
+  // Get M, N, and K dimension size
+  int64_t M = lhsType.getDimSize(lhsType.getRank() - 2);
+  int64_t N = rhsType.getDimSize(rhsType.getRank() - 1);
+  int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
+  int64_t vnni = 1;
+
+  if (K != 1)
+    return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+
+  if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "The reduction-dim should be 1");
+
+  // The accumulator is split into N / vectorSize chunks per row; a remainder
+  // would make the chunk loads run past the tile.
+  if (N % vectorSize != 0)
+    return rewriter.notifyMatchFailure(
+        contractOp, "The n-dim should be a multiple of the vector size");
+
+  // The reassembled M * N vector is shape_cast to the accumulator type and
+  // replaces the loop result, so both types must line up.
+  auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+  if (!accTy || accTy.getNumElements() != M * N ||
+      accForOp.getResult(0).getType() != accTy)
+    return rewriter.notifyMatchFailure(
+        contractOp, "Expects an MxN vector accumulator carried by the loop");
+
+  rewriter.setInsertionPoint(accForOp);
+
+  // Load  MxN C sub matrix into acc vectors (e.g, <vectorSizexf32>)
+  SmallVector<Value> loopItrArgs =
+      loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+
+  scf::ForOp newLoop;
+  if (dimCount == 4) {
+    // Create the batch-reduce and K loops with the acc vectors as iter_args,
+    // plus the nanokernel.
+    newLoop = rewriter.create<scf::ForOp>(
+        accForOp.getLoc(), accForOp.getLowerBound(), accForOp.getUpperBound(),
+        accForOp.getStep(), loopItrArgs,
+        [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
+            Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
+          scf::ForOp newKForOp = createLoop(
+              rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+              ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
+              iterArgsNewReductionForOp, dimCount);
+
+          rewriterNewReductionForOp.create<scf::YieldOp>(
+              locNewReductionForOp, newKForOp.getResults());
+        });
+  } else {
+    // Create only the K loop, plus the nanokernel.
+    newLoop =
+        createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+                   elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount);
+  }
+
+  // Combine all acc vectors into a MxN C matrix
+  auto vecType = VectorType::get({M * N}, elementType);
+  auto zeroAttr =
+      DenseElementsAttr::get(vecType, rewriter.getFloatAttr(elementType, 0.0));
+  Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
+
+  accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec,
+                     vectorSize, M, N);
+
+  auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
+
+  // Redirect all uses of the accumulator loop's result to the nanokernel
+  // result.
+  rewriter.replaceAllUsesWith(accForOp.getResult(0), reshapeAcc);
+
+  return success();
+}

>From 176cb06fed4c7358603dba8453ab3c1d79910dbd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 22 Oct 2025 04:29:22 -0700
Subject: [PATCH 2/6] initial code to wrap as a Transform dialect

---
 .../mlir/Dialect/X86Vector/CMakeLists.txt     |  2 +
 .../X86Vector/TransformOps/CMakeLists.txt     |  4 ++
 .../TransformOps/X86VectorTransformOps.td     | 38 +++++++++++++++++
 .../mlir/Dialect/X86Vector/Transforms.h       |  4 +-
 mlir/lib/Dialect/X86Vector/CMakeLists.txt     |  1 +
 .../X86Vector/TransformOps/CMakeLists.txt     | 20 +++++++++
 .../TransformOps/X86VectorTransformOps.cpp    | 42 +++++++++++++++++++
 .../X86Vector/Transforms/NanoKernels.cpp      | 21 +++++++++-
 8 files changed, 128 insertions(+), 4 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
 create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp

diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
 
 add_mlir_interface(X86VectorInterfaces)
 add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..23f0eebaebe34
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,38 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+// Pattern-descriptor transform op; its populatePatterns hook is implemented in
+// X86VectorTransformOps.cpp.
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contract operation can be lowered to target
+    specific nanokernels.
+  }];
+
+  // TODO: expose the nanokernel vector width as an optional attribute, e.g.
+  //   let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+  //   let assemblyFormat = [{ (`vector_size` `=` $vector_size^)? attr-dict }];
+  // Until then the op takes no arguments.
+
+  let assemblyFormat = "attr-dict";
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index cde890038f20a..6ddb4e542ba54 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -85,8 +85,8 @@ struct MaskHelper {
 
 //===----------------------------------------------------------------------===//
 // Nano-kernels
-LogicalResult nanoKernels(RewriterBase &rewriter,
-                           vector::ContractionOp contractOp, int64_t vectorSize);
+/// Adds to `patterns` the rewrite that lowers vector.contract ops inside
+/// register-tiled matmul loop nests into sequences of vector.fma ops
+/// ("nanokernels").
+void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns);
 
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8814547620e58
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRVectorToLLVM
+  MLIRVectorTransforms
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRVectorDialect
+  MLIRVectorToSCF
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..1570972db1855
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,42 @@
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+
+
+/// Populates `patterns` for the
+/// `apply_patterns.x86vector.vector_contract_nanokernel_lowering` op.
+// NOTE(review): this registers vector::populateVectorTransferLoweringPatterns,
+// not x86vector::populateVectorContractNanokernelLoweringPatterns declared in
+// Transforms.h; looks like a placeholder -- confirm.
+// TODO: forward a configurable vector size once the op takes one.
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorTransferLoweringPatterns(patterns);
+}
+
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+/// Registers the X86Vector transform dialect extension with `registry`.
+// NOTE(review): X86VectorTransformDialectExtension is not declared anywhere in
+// this file; presumably it still has to be defined before this function.
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
+
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index bc03b567e06ff..5aeba6cad0445 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -21,6 +21,10 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
@@ -331,9 +335,16 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
   return newKForOp;
 }
 
-LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
-                                           vector::ContractionOp contractOp, int64_t vectorSize) {
+//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
+                  //                         vector::ContractionOp contractOp)//, int64_t vectorSize) {
+
+struct VectorContractNanokernelLowering final : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
   auto loc = contractOp.getLoc();
+  int64_t vectorSize = 16;
 
   if (contractOp.getKind() != vector::CombiningKind::ADD) {
     return rewriter.notifyMatchFailure(contractOp,
@@ -481,3 +492,9 @@ LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
 
   return success();
 }
+};
+
+
+/// Adds the VectorContractNanokernelLowering pattern (vector.contract ->
+/// vector.fma nanokernels) to `patterns`.
+void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) {
+  patterns.add<VectorContractNanokernelLowering>(patterns.getContext());
+}

>From 15f4c5d54ccf3c820097dc8f4e2dc8d2999e3fe8 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 22 Oct 2025 07:37:41 -0700
Subject: [PATCH 3/6] fix build errors in transform code

---
 .../TransformOps/X86VectorTransformOps.h      | 36 +++++++++++++++++++
 .../X86Vector/TransformOps/CMakeLists.txt     |  2 +-
 .../TransformOps/X86VectorTransformOps.cpp    | 30 +++++++++++++---
 3 files changed, 63 insertions(+), 5 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..abb5da75e5bfd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,36 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace x86vector {
+} // namespace x86vector
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+/// Registers the X86Vector transform dialect extension with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
index 8814547620e58..5f85f7af60d01 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -14,7 +14,7 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
   MLIRSideEffectInterfaces
   MLIRTransformDialect
   MLIRTransformDialectUtils
-  MLIRVectorDialect
+  MLIRX86VectorDialect
   MLIRVectorToSCF
   MLIRX86VectorTransforms
   )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 1570972db1855..5702106558409 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -14,11 +14,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 
-#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
-#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/RegionKindInterface.h"
 
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
 using namespace mlir;
 using namespace mlir::x86vector;
 using namespace mlir::transform;
@@ -27,10 +29,31 @@ using namespace mlir::transform;
 
 void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorTransferLoweringPatterns(patterns);//,
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//,
                                                  //getVectorSize());
 }
 
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
@@ -39,4 +62,3 @@ void mlir::x86vector::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<X86VectorTransformDialectExtension>();
 }
-

>From ca4e25291d23c9140fbbc4904f816437aa2d48bb Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 23 Oct 2025 23:08:27 -0700
Subject: [PATCH 4/6] Round 1: environment setup + clang-format fix

---
 .../TransformOps/X86VectorTransformOps.h      |  11 +-
 .../TransformOps/X86VectorTransformOps.td     |   9 +-
 .../mlir/Dialect/X86Vector/Transforms.h       |   4 +-
 .../X86Vector/TransformOps/CMakeLists.txt     |   3 -
 .../TransformOps/X86VectorTransformOps.cpp    |  23 +-
 .../X86Vector/Transforms/NanoKernels.cpp      | 363 +++++++++---------
 mlir/lib/RegisterAllExtensions.cpp            |   2 +
 .../vector-contract-to-nanokernels.mlir       |  48 +++
 8 files changed, 256 insertions(+), 207 deletions(-)
 create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
index abb5da75e5bfd..e1d8b8762e799 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -1,4 +1,4 @@
-//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===//
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -12,13 +12,8 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 
-namespace mlir {
-namespace x86vector {
-} // namespace vector
-} // namespace mlir
-
 //===----------------------------------------------------------------------===//
-// Vector Transform Operations
+// X86Vector Transform Operations
 //===----------------------------------------------------------------------===//
 
 #define GET_OP_CLASSES
@@ -30,7 +25,7 @@ class DialectRegistry;
 namespace x86vector {
 void registerTransformDialectExtension(DialectRegistry &registry);
 
-} // namespace vector
+} // namespace x86vector
 } // namespace mlir
 
 #endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 23f0eebaebe34..9db2b36a2a8aa 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -25,13 +25,12 @@ def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
     specific nanokernels.
   }];
 
-  //let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+  let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
 
-  //let assemblyFormat = [{
-    //(`vector_size` `=` $vector_size^)? attr-dict
-  //}];
+  let assemblyFormat = [{
+    (`vector_size` `=` $vector_size^)? attr-dict
+  }];
 
-  let assemblyFormat = "attr-dict";
 }
 
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 6ddb4e542ba54..a9487adba002a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -85,8 +85,8 @@ struct MaskHelper {
 
 //===----------------------------------------------------------------------===//
 // Nano-kernels
-void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns); //RewriterBase &rewriter,
-                           //vector::ContractionOp contractOp, int64_t vectorSize);
+void populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
 
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
index 5f85f7af60d01..f4c9f8a05acbc 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -9,12 +9,9 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRVectorDialect
-  MLIRVectorToLLVM
-  MLIRVectorTransforms
   MLIRSideEffectInterfaces
   MLIRTransformDialect
   MLIRTransformDialectUtils
   MLIRX86VectorDialect
-  MLIRVectorToSCF
   MLIRX86VectorTransforms
   )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 5702106558409..e003e3ad7cd08 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,4 +1,5 @@
-//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===//
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
+//--===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,31 +7,26 @@
 //
 //===----------------------------------------------------------------------===//
 
-
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
-
-#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/RegionKindInterface.h"
 
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-
 using namespace mlir;
 using namespace mlir::x86vector;
 using namespace mlir::transform;
 
-
-
-void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//,
-                                                 //getVectorSize());
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+                                                              getVectorSize());
 }
 
 //===----------------------------------------------------------------------===//
@@ -42,7 +38,8 @@ class X86VectorTransformDialectExtension
     : public transform::TransformDialectExtension<
           X86VectorTransformDialectExtension> {
 public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension)
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
 
   X86VectorTransformDialectExtension() {
     declareGeneratedDialect<x86vector::X86VectorDialect>();
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index 5aeba6cad0445..583334333a49d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -16,30 +16,28 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
 using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
 static FailureOr<SmallVector<scf::ForOp>>
-getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
+getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
   SmallVector<scf::ForOp> list;
   Operation *current = contractOp;
   // It is register tiled loop structure on batch reduce matmul
   // (M->N->Batch-reduce->K).
-  for (int i = 0; i < dimCount; i++) {
+  for (unsigned int i = 0; i < dimCount; i++) {
     Operation *parent = current->getParentOfType<scf::ForOp>();
     if (!parent)
       return failure();
@@ -51,7 +49,7 @@ getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) {
 
 static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
                                      SmallVector<memref::SubViewOp> subviews,
-                                     int64_t dimCount) {
+                                     unsigned int dimCount) {
   auto subviewOpLhsOffsets = subviews[0].getOffsets();
   auto subviewOpRhsOffsets = subviews[1].getOffsets();
   auto subviewOpAccOffsets = subviews[2].getOffsets();
@@ -93,15 +91,16 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
 }
 
 static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
-                                  Type elementType, int64_t M, int64_t N,
-                                  int64_t vectorSize, Value subviewOpAcc) {
+                                  Type elementType, unsigned int M,
+                                  unsigned int N, unsigned int vectorSize,
+                                  Value subviewOpAcc) {
 
   SmallVector<Value> loopItrArgs;
-  int64_t outerBound = M;
-  int64_t innerBound = N;
+  unsigned int outerBound = M;
+  unsigned int innerBound = N;
 
-  int64_t outerStep = 1;
-  int64_t innerStep = vectorSize;
+  unsigned int outerStep = 1;
+  unsigned int innerStep = vectorSize;
 
   if ((N / vectorSize) > M) {
     outerBound = N;
@@ -111,8 +110,8 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
     innerStep = 1;
   }
 
-  for (int i = 0; i < outerBound; i = i + outerStep) {
-    for (int j = 0; j < innerBound; j = j + innerStep) {
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
       Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(loc, i);
       Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, j);
 
@@ -132,28 +131,28 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
 }
 
 SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
-                               Type elementType, int64_t vectorSize,
-                               int64_t vnni, int64_t M, int64_t N,
-                               ValueRange acc, Value matA, Value matB,
-                               int64_t dimCount) {
+                               Type elementType, unsigned int vectorSize,
+                               unsigned int vnni, unsigned int M,
+                               unsigned int N, ValueRange acc, Value matA,
+                               Value matB, unsigned int dimCount) {
 
   SmallVector<Value> accVector;
   SmallVector<Value> matLoad;
   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
-  int64_t outerBound = M;
-  int64_t outerStep = 1;
+  unsigned int outerBound = M;
+  unsigned int outerStep = 1;
 
-  int64_t innerBound = N;
-  int64_t innerStep = vectorSize;
+  unsigned int innerBound = N;
+  unsigned int innerStep = vectorSize;
 
   Value outerMatrix = matA;
   Value innerMatrix = matB;
 
-  int64_t outerVectSize = vnni;
-  int64_t innerVectSize = vectorSize;
+  unsigned int outerVectSize = vnni;
+  unsigned int innerVectSize = vectorSize;
 
-  int64_t fmaBound = M;
+  unsigned int fmaBound = M;
 
   if ((N / vectorSize) < M) {
     outerBound = N;
@@ -171,7 +170,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
     fmaBound = N / vectorSize;
   }
 
-  for (int i = 0; i < outerBound; i = i + outerStep) {
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
     Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(loc, i);
     Value valueRow;
 
@@ -200,7 +199,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
     matLoad.push_back(valueRow);
   }
 
-  for (int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+  for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
     Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(loc, j);
     Value valueRow;
 
@@ -225,7 +224,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
           loc, VectorType::get(innerVectSize, elementType), innerMatrix, index);
     }
 
-    for (int i = 0; i < fmaBound; i = i + 1) {
+    for (unsigned int i = 0; i < fmaBound; i = i + 1) {
       auto fmaOdd =
           rewriter.create<vector::FMAOp>(loc, matLoad[i], valueRow, acc[k]);
       k++;
@@ -237,14 +236,14 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
 }
 
 Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
-                SmallVector<Value> FMAs, Value accVec, int64_t vecSize,
-                int64_t M, int64_t N) {
+                SmallVector<Value> FMAs, Value accVec, unsigned int vecSize,
+                unsigned int M, unsigned int N) {
 
   auto strides = rewriter.getI64ArrayAttr({1});
   if ((N / vecSize) > M) {
-    for (int j = 0, k = 0; j < (N / vecSize); j++) {
-      for (int i = 0; i < M; i++) {
-        int64_t off = (j * vecSize) + (i * N);
+    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+      for (unsigned int i = 0; i < M; i++) {
+        unsigned int off = (j * vecSize) + (i * N);
         auto offsets = rewriter.getI64ArrayAttr({off});
         accVec = rewriter.create<vector::InsertStridedSliceOp>(
             loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -253,7 +252,7 @@ Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
     }
 
   } else {
-    for (int i = 0, k = 0; i < M * N; i = i + vecSize) {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
       auto offsets = rewriter.getI64ArrayAttr({i});
       accVec = rewriter.create<vector::InsertStridedSliceOp>(
           loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -267,8 +266,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
                       vector::TransferReadOp vectorReadOpLhs,
                       vector::TransferReadOp vectorReadOpRhs,
                       Value ivNewReductionForOp, Type elementType,
-                      int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
-                      ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+                      unsigned int vectorSize, unsigned int vnni,
+                      unsigned int M, unsigned int N,
+                      ValueRange iterArgsNewReductionForOp,
+                      unsigned int dimCount) {
   auto newKForOp = rewriter.create<scf::ForOp>(
       kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
@@ -304,8 +305,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
 scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
                       vector::TransferReadOp vectorReadOpLhs,
                       vector::TransferReadOp vectorReadOpRhs, Type elementType,
-                      int64_t vectorSize, int64_t vnni, int64_t M, int64_t N,
-                      ValueRange iterArgsNewReductionForOp, int64_t dimCount) {
+                      unsigned int vectorSize, unsigned int vnni,
+                      unsigned int M, unsigned int N,
+                      ValueRange iterArgsNewReductionForOp,
+                      unsigned int dimCount) {
 
   auto newKForOp = rewriter.create<scf::ForOp>(
       kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
@@ -335,166 +338,174 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
   return newKForOp;
 }
 
-//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter,
-                  //                         vector::ContractionOp contractOp)//, int64_t vectorSize) {
-
-struct VectorContractNanokernelLowering final : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+struct VectorContractNanokernelLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+  VectorContractNanokernelLowering(MLIRContext *context,
+                                   std::optional<unsigned> vecSize)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        userVectorSize(vecSize) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-  auto loc = contractOp.getLoc();
-  int64_t vectorSize = 16;
-
-  if (contractOp.getKind() != vector::CombiningKind::ADD) {
-    return rewriter.notifyMatchFailure(contractOp,
-                                       "Expects add combining kind");
-  }
-
-  auto dimCount = contractOp.getRhsType().getRank() + 1;
 
-  if ((dimCount != 3) && (dimCount != 4))
-    return rewriter.notifyMatchFailure(contractOp,
-                                       "Expects batch-reduce or batch matmuls");
+    auto loc = contractOp.getLoc();
 
-  // Get the M, N, K, and batch-reduce loops
-  auto loops = getNestedLoop(contractOp, dimCount);
-  if (failed(loops))
-    return rewriter.notifyMatchFailure(contractOp,
-                                       "Invalid loop nest in contract pattern");
+    unsigned int vectorSize = 8;
 
-  auto nestedLoops = *loops;
-  scf::ForOp kForOp = nestedLoops[0];
-  scf::ForOp reductionForOp; 
+    if (userVectorSize)
+      vectorSize = *userVectorSize;
 
-  vector::TransferReadOp vectorReadOpAcc;
+    if (contractOp.getKind() != vector::CombiningKind::ADD) {
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+    }
 
-  if (dimCount == 4) {
-    reductionForOp = nestedLoops[1];
-    vectorReadOpAcc =
-        reductionForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
-  }
+    auto dimCount = contractOp.getRhsType().getRank() + 1;
 
-  if (dimCount == 3) {
-    vectorReadOpAcc =
-        kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
-  }
+    if ((dimCount != 3) && (dimCount != 4))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects batch-reduce or batch matmuls");
 
-  auto vectorReadOpLhs =
-      contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
-  auto vectorReadOpRhs =
-      contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+    // Get the M, N, K, and batch-reduce loops
+    auto loops = getNestedLoop(contractOp, dimCount);
+    if (failed(loops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid loop nest in contract pattern");
 
-  if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
-    return failure();
+    auto nestedLoops = *loops;
+    scf::ForOp kForOp = nestedLoops[0];
+    scf::ForOp reductionForOp;
 
-  auto subviewOpAcc =
-      vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
-  auto subviewOpLhs =
-      vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
-  auto subviewOpRhs =
-      vectorReadOpRhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    vector::TransferReadOp vectorReadOpAcc;
 
-  if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
-    return failure();
-
-  SmallVector<memref::SubViewOp> subviews;
-  subviews.push_back(subviewOpLhs);
-  subviews.push_back(subviewOpRhs);
-  subviews.push_back(subviewOpAcc);
+    if (dimCount == 4) {
+      reductionForOp = nestedLoops[1];
+      vectorReadOpAcc = reductionForOp.getInitArgs()[0]
+                            .getDefiningOp<vector::TransferReadOp>();
+    }
 
-  // The M, N, K, and batch-reduce loop iv should match the iv's 
-  // used in the subviews
-  auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
-  if (failed(checkLoops))
-    return rewriter.notifyMatchFailure(
-        contractOp, "Loops doesn't match the iv in subviews");
+    if (dimCount == 3) {
+      vectorReadOpAcc =
+          kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+    }
 
-  auto elementType =
-      (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
+    auto vectorReadOpLhs =
+        contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+    auto vectorReadOpRhs =
+        contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
 
-  // TODO: Support for BF16 Type
-  if (!elementType.isF32())
-    return rewriter.notifyMatchFailure(
-        contractOp, "Only, FP32 type is supported");
+    if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
+      return failure();
 
-  auto lhsType = dyn_cast<ShapedType>(vectorReadOpLhs.getType());
-  auto rhsType = dyn_cast<ShapedType>(vectorReadOpRhs.getType());
+    auto subviewOpAcc =
+        vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    auto subviewOpLhs =
+        vectorReadOpLhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    auto subviewOpRhs =
+        vectorReadOpRhs.getOperand(0).getDefiningOp<memref::SubViewOp>();
 
-  // Get M, N, and K dimension size
-  int64_t M = lhsType.getDimSize(lhsType.getRank() - 2);
-  int64_t N = rhsType.getDimSize(rhsType.getRank() - 1);
-  int64_t K = lhsType.getDimSize(lhsType.getRank() - 1);
-  int64_t vnni = 1;
+    if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs)
+      return failure();
 
-  if (K != 1)
-    return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+    SmallVector<memref::SubViewOp> subviews;
+    subviews.push_back(subviewOpLhs);
+    subviews.push_back(subviewOpRhs);
+    subviews.push_back(subviewOpAcc);
+
+    // The M, N, K, and batch-reduce loop iv should match the iv's
+    // used in the subviews
+    auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
+    if (failed(checkLoops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Loops doesn't match the iv in subviews");
+
+    auto elementType =
+        (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
+
+    // TODO: Support for BF16 Type
+    if (!elementType.isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only, FP32 type is supported");
+
+    auto lhsType = dyn_cast<ShapedType>(vectorReadOpLhs.getType());
+    auto rhsType = dyn_cast<ShapedType>(vectorReadOpRhs.getType());
+
+    // Get M, N, and K dimension size
+    unsigned int M = lhsType.getDimSize(lhsType.getRank() - 2);
+    unsigned int N = rhsType.getDimSize(rhsType.getRank() - 1);
+    unsigned int K = lhsType.getDimSize(lhsType.getRank() - 1);
+    unsigned int vnni = 1;
+
+    if (K != 1)
+      return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
+
+    if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "The reduction-dim should be 1");
+
+    if (dimCount == 4)
+      rewriter.setInsertionPoint(reductionForOp);
+
+    if (dimCount == 3)
+      rewriter.setInsertionPoint(kForOp);
+
+    // Load  MxN C sub matrix into acc vectors (e.g, <vectorSizexf32>)
+    SmallVector<Value> loopItrArgs =
+        loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+
+    // Create the batch-reduce and K-loop with acc vectors as the loop
+    // iterargs (batch-reduce matmul) + nanokernel generation
+    scf::ForOp newLoop;
+    if (dimCount == 4) {
+      newLoop = rewriter.create<scf::ForOp>(
+          reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+          reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
+          [&](OpBuilder &rewriterNewReductionForOp,
+              Location locNewReductionForOp, Value ivNewReductionForOp,
+              ValueRange iterArgsNewReductionForOp) {
+            scf::ForOp newKForOp = createLoop(
+                rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
+                ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
+                iterArgsNewReductionForOp, dimCount);
+
+            rewriterNewReductionForOp.create<scf::YieldOp>(
+                locNewReductionForOp, newKForOp.getResults());
+          });
+    }
 
-  if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
-    return rewriter.notifyMatchFailure(contractOp,
-                                       "The reduction-dim should be 1");
+    // Create only the K-loop (batch matmul) + nanokernel generation
+    if (dimCount == 3) {
+      newLoop = createLoop(rewriter, loc, kForOp, vectorReadOpLhs,
+                           vectorReadOpRhs, elementType, vectorSize, vnni, M, N,
+                           loopItrArgs, dimCount);
+    }
 
+    // Combine all acc vectors into a MxN C matrix
+    auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
+    auto zeroAttr =
+        DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
+    Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
 
-  if (dimCount == 4)
-    rewriter.setInsertionPoint(reductionForOp);
+    accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec,
+                       vectorSize, M, N);
 
-  if (dimCount == 3)
-    rewriter.setInsertionPoint(kForOp);
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
 
-  // Load  MxN C sub matrix into acc vectors (e.g, <vectorSizexf32>)
-  SmallVector<Value> loopItrArgs =
-      loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+    // Replace all the use of vector.contract with results of nanokernels
+    if (dimCount == 4)
+      rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
 
-  // Create the batch-reduce and K-loop with acc vectors as the loop
-  // iterargs (batch-reduce matmul) + nanokernel generation
-  scf::ForOp newLoop;
-  if (dimCount == 4) {
-    newLoop = rewriter.create<scf::ForOp>(
-        reductionForOp.getLoc(), reductionForOp.getLowerBound(),
-        reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
-        [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
-            Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
-          scf::ForOp newKForOp = createLoop(
-              rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
-              ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
-              iterArgsNewReductionForOp, dimCount);
-
-          rewriterNewReductionForOp.create<scf::YieldOp>(
-              locNewReductionForOp, newKForOp.getResults());
-        });
-  }
+    if (dimCount == 3)
+      rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
 
-  // Create only the K-loop (batch matmul) + nanokernel generation
-  if (dimCount == 3) {
-    newLoop =
-        createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
-                   elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount);
+    return success();
   }
-
-
-  // Combine all acc vectors into a MxN C matrix
-  auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
-  auto zeroAttr =
-      DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
-  Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
-
-  accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N);
-
-  auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
-  auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
-
-  // Replace all the use of vector.contract with results of nanokernels
-  if (dimCount == 4)
-    rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
-
-  if (dimCount == 3)
-    rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
-
-  return success();
-}
+  std::optional<unsigned> userVectorSize;
 };
 
-
-void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) {
-  patterns.add<VectorContractNanokernelLowering>(patterns.getContext());
+void x86vector::populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> userVectorSize) {
+  patterns.add<VectorContractNanokernelLowering>(patterns.getContext(),
+                                                 userVectorSize);
 }
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..efcd09fc1b924 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  x86vector::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
 
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
new file mode 100644
index 0000000000000..78ff150bb776e
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+module {
+  func.func @fp32_vectorSize_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %0 = ub.poison : f32
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c96 = arith.constant 96 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    scf.for %arg3 = %c0 to %c4 step %c4 {
+      scf.for %arg4 = %c0 to %c96 step %c96 {
+        %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>>
+        %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32>
+        %2 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %1) -> (vector<4x96xf32>) {
+          %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x96xf32>) {
+            %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>
+            %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 96] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>
+            %4 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32>
+            %5 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x96xf32>
+            %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x96xf32> into vector<4x96xf32>
+            scf.yield %6 : vector<4x96xf32>
+          }
+          scf.yield %3 : vector<4x96xf32>
+        }
+        vector.transfer_write %2, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>>
+      }
+    }
+    return %arg2 : memref<4x96xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @fp32_vectorSize_16(
+// CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From ca52bdc8ee1f06622053b10e1c0c48b8c734372e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 24 Oct 2025 09:09:37 -0700
Subject: [PATCH 5/6] rebase from the main branch

---
 .../X86Vector/Transforms/NanoKernels.cpp      | 285 +++++++++++-------
 .../vector-contract-to-nanokernels.mlir       |  45 +++
 2 files changed, 221 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index 583334333a49d..c8270c96daeea 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -90,12 +90,12 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
   return success();
 }
 
-static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
-                                  Type elementType, unsigned int M,
-                                  unsigned int N, unsigned int vectorSize,
-                                  Value subviewOpAcc) {
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
 
-  SmallVector<Value> loopItrArgs;
+  SmallVector<Value> accumulators;
   unsigned int outerBound = M;
   unsigned int innerBound = N;
 
@@ -112,34 +112,74 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
 
   for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
     for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
-      Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(loc, i);
-      Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
 
       if ((N / vectorSize) > M) {
         indexOp_A = indexOp_B;
-        indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
       }
 
-      auto valueCRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+      auto valueCRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
           ValueRange{indexOp_A, indexOp_B});
-      loopItrArgs.push_back(valueCRow);
+      accumulators.push_back(valueCRow);
     }
   }
 
-  return loopItrArgs;
+  return accumulators;
 }
 
-SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
-                               Type elementType, unsigned int vectorSize,
-                               unsigned int vnni, unsigned int M,
-                               unsigned int N, ValueRange acc, Value matA,
-                               Value matB, unsigned int dimCount) {
-
-  SmallVector<Value> accVector;
+// Function accepts A Matrix, B Matrix, C Matrix (as vectors) and generate
+// equivalent target specific nanokernels. Returns the final accumulator as
+// output. Based on M tile, N tile, and vector size it generated optimized
+// nanokernels with condition of reduction and K dimension of the input matrix
+// are 1.
+//
+// Input: Matrix A, Matrix B, Accmulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M > (N/vector size). For example, M=3; N=32; vector size = 16.
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M <= (N/vector size). For example, M=2; N=48; vector size = 16.
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  load_B2 = load B[32-47] into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    unsigned int dimCount) {
+
+  SmallVector<Value> accumulators;
   SmallVector<Value> matLoad;
-  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
 
+  // Start with assumption that M tile size is smaller and create  the
+  // helper variables
   unsigned int outerBound = M;
   unsigned int outerStep = 1;
 
@@ -154,6 +194,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
 
   unsigned int fmaBound = M;
 
+  // update helper variables if N tile size is smaller
   if ((N / vectorSize) < M) {
     outerBound = N;
     innerBound = M;
@@ -170,37 +211,52 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
     fmaBound = N / vectorSize;
   }
 
+  // Load all the element of A or B matrix
   for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
-    Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(loc, i);
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
     Value valueRow;
 
     if ((N / vectorSize) > M) {
 
+      // With the assumption as batch-reduce matmul initialize reduction, M, and
+      // K dimension.
       SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
       if (dimCount == 3) {
         index.erase(index.begin());
       }
 
-      Value row = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
-      valueRow = rewriter.create<vector::BroadcastOp>(
-          loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
     } else {
 
+      // With the assumption as batch-reduce matmul initialize reduction, K, and
+      // N dimension.
       SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove reduction dimension if it is a batch matmul
       if (dimCount == 3) {
         index.erase(index.begin());
       }
 
-      valueRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
     }
 
     matLoad.push_back(valueRow);
   }
 
+  // Load elements of A/B Matrix one at a time and compute FMA
   for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
-    Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(loc, j);
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
     Value valueRow;
 
     if ((N / vectorSize) < M) {
@@ -208,11 +264,14 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
       if (dimCount == 3) {
         index.erase(index.begin());
       }
-      Value row = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(innerVectSize, elementType), innerMatrix,
-          ValueRange(index));
-      valueRow = rewriter.create<vector::BroadcastOp>(
-          loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
     } else {
 
       SmallVector<Value> index = {c0, c0, indexOp_j};
@@ -220,58 +279,35 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
         index.erase(index.begin());
       }
 
-      valueRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(innerVectSize, elementType), innerMatrix, index);
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
     }
 
+    // FMAs
     for (unsigned int i = 0; i < fmaBound; i = i + 1) {
       auto fmaOdd =
-          rewriter.create<vector::FMAOp>(loc, matLoad[i], valueRow, acc[k]);
+          vector::FMAOp::create(rewriter, loc, matLoad[i], valueRow, acc[k]);
       k++;
-      accVector.push_back(fmaOdd);
+      accumulators.push_back(fmaOdd);
     }
   }
 
-  return accVector;
-}
-
-Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
-                SmallVector<Value> FMAs, Value accVec, unsigned int vecSize,
-                unsigned int M, unsigned int N) {
-
-  auto strides = rewriter.getI64ArrayAttr({1});
-  if ((N / vecSize) > M) {
-    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
-      for (unsigned int i = 0; i < M; i++) {
-        unsigned int off = (j * vecSize) + (i * N);
-        auto offsets = rewriter.getI64ArrayAttr({off});
-        accVec = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, vecType, FMAs[k], accVec, offsets, strides);
-        k++;
-      }
-    }
-
-  } else {
-    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
-      auto offsets = rewriter.getI64ArrayAttr({i});
-      accVec = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vecType, FMAs[k], accVec, offsets, strides);
-      k++;
-    }
-  }
-  return accVec;
+  return accumulators;
 }
 
-scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
-                      vector::TransferReadOp vectorReadOpLhs,
-                      vector::TransferReadOp vectorReadOpRhs,
-                      Value ivNewReductionForOp, Type elementType,
-                      unsigned int vectorSize, unsigned int vnni,
-                      unsigned int M, unsigned int N,
-                      ValueRange iterArgsNewReductionForOp,
-                      unsigned int dimCount) {
-  auto newKForOp = rewriter.create<scf::ForOp>(
-      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+// Function to re-create K dimension loop with accumulator as IterArgs for
+// lowering a batch-reduce vector contraction to a system specific nanokernels.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
+    Type elementType, unsigned int vectorSize, unsigned int vnni,
+    unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
+    unsigned int dimCount) {
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
       [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
           Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
@@ -291,27 +327,28 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
         auto rhsClone = rewriterNewKForOp.clone(
             *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
 
-        auto evenFMAs =
-            nanoKernels(rewriter, kForOp.getLoc(), elementType, vectorSize,
-                        vnni, M, N, iterArgsNewKForOp, lhsClone->getResult(0),
-                        rhsClone->getResult(0), dimCount);
+        auto evenFMAs = generateNanokernels(
+            rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
+            iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
+            dimCount);
 
-        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
 
   return newKForOp;
 }
 
-scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
-                      vector::TransferReadOp vectorReadOpLhs,
-                      vector::TransferReadOp vectorReadOpRhs, Type elementType,
-                      unsigned int vectorSize, unsigned int vnni,
-                      unsigned int M, unsigned int N,
-                      ValueRange iterArgsNewReductionForOp,
-                      unsigned int dimCount) {
-
-  auto newKForOp = rewriter.create<scf::ForOp>(
-      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+// Function to re-create K dimension loop with accumulator as IterArgs for
+// lowering a batch vector contraction to a system specific nanokernels.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Type elementType,
+    unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
+    ValueRange iterArgsNewReductionForOp, unsigned int dimCount) {
+
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
       [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
           Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
@@ -328,16 +365,45 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
             *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
 
         auto evenFMAs =
-            nanoKernels(rewriter, loc, elementType, vectorSize, vnni, M, N,
-                        iterArgsNewKForOp, lhsClone->getResult(0),
-                        rhsClone->getResult(0), dimCount);
+            generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
+                                N, iterArgsNewKForOp, lhsClone->getResult(0),
+                                rhsClone->getResult(0), dimCount);
 
-        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
 
   return newKForOp;
 }
 
+Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
+                                     VectorType vecType,
+                                     SmallVector<Value> FMAs, Value accVec,
+                                     unsigned int vecSize, unsigned int M,
+                                     unsigned int N) {
+
+  auto strides = rewriter.getI64ArrayAttr({1});
+  if ((N / vecSize) > M) {
+    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+      for (unsigned int i = 0; i < M; i++) {
+        unsigned int off = (j * vecSize) + (i * N);
+        auto offsets = rewriter.getI64ArrayAttr({off});
+        accVec = vector::InsertStridedSliceOp::create(
+            rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+        k++;
+      }
+    }
+
+  } else {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
+      auto offsets = rewriter.getI64ArrayAttr({i});
+      accVec = vector::InsertStridedSliceOp::create(
+          rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+      k++;
+    }
+  }
+  return accVec;
+}
+
 struct VectorContractNanokernelLowering
     : public OpRewritePattern<vector::ContractionOp> {
   VectorContractNanokernelLowering(MLIRContext *context,
@@ -450,47 +516,48 @@ struct VectorContractNanokernelLowering
       rewriter.setInsertionPoint(kForOp);
 
     // Load  MxN C sub matrix into acc vectors (e.g, <vectorSizexf32>)
-    SmallVector<Value> loopItrArgs =
-        loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+    SmallVector<Value> accumulators = loadAccumulatorBeforeGEMM(
+        loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
 
     // Create the batch-reduce and K-loop with acc vectors as the loop
     // iterargs (batch-reduce matmul) + nanokernel generation
     scf::ForOp newLoop;
     if (dimCount == 4) {
-      newLoop = rewriter.create<scf::ForOp>(
-          reductionForOp.getLoc(), reductionForOp.getLowerBound(),
-          reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
+      newLoop = scf::ForOp::create(
+          rewriter, reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+          reductionForOp.getUpperBound(), reductionForOp.getStep(),
+          accumulators,
           [&](OpBuilder &rewriterNewReductionForOp,
               Location locNewReductionForOp, Value ivNewReductionForOp,
               ValueRange iterArgsNewReductionForOp) {
-            scf::ForOp newKForOp = createLoop(
+            scf::ForOp newKForOp = createGEMMLoopsWithAccAsIterArgs(
                 rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
                 ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
                 iterArgsNewReductionForOp, dimCount);
 
-            rewriterNewReductionForOp.create<scf::YieldOp>(
-                locNewReductionForOp, newKForOp.getResults());
+            scf::YieldOp::create(rewriterNewReductionForOp,
+                                 locNewReductionForOp, newKForOp.getResults());
           });
     }
 
     // Create only the K-loop (batch matmul) + nanokernel generation
     if (dimCount == 3) {
-      newLoop = createLoop(rewriter, loc, kForOp, vectorReadOpLhs,
-                           vectorReadOpRhs, elementType, vectorSize, vnni, M, N,
-                           loopItrArgs, dimCount);
+      newLoop = createGEMMLoopsWithAccAsIterArgs(
+          rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, elementType,
+          vectorSize, vnni, M, N, accumulators, dimCount);
     }
 
     // Combine all acc vectors into a MxN C matrix
     auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
     auto zeroAttr =
         DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
-    Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
+    Value accVec = arith::ConstantOp::create(rewriter, loc, vecType, zeroAttr);
 
-    accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec,
-                       vectorSize, M, N);
+    accVec = mergeAccumulatedVectorAsMatrix(
+        rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N);
 
     auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
+    auto reshapeAcc = vector::ShapeCastOp::create(rewriter, loc, accTy, accVec);
 
     // Replace all the use of vector.contract with results of nanokernels
     if (dimCount == 4)
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
index 78ff150bb776e..184ba346e8638 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -46,3 +46,48 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+module {
+  func.func @fp32_batch_matmul_vector_size_8(%arg0: memref<4x32xf32>, %arg1: memref<32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %0 = ub.poison : f32
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c96 = arith.constant 96 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    scf.for %arg3 = %c0 to %c4 step %c4 {
+      scf.for %arg4 = %c0 to %c96 step %c96 {
+        %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>>
+        %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32>
+
+          %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %1) -> (vector<4x96xf32>) {
+            %subview_0 = memref.subview %arg0[%arg3, %arg7] [4, 1] [1, 1] : memref<4x32xf32> to memref<4x1xf32, strided<[32, 1], offset: ?>>
+            %subview_1 = memref.subview %arg1[%arg7, %arg4] [1, 96] [1, 1] : memref<32x96xf32> to memref<1x96xf32, strided<[96, 1], offset: ?>>
+            %4 = vector.transfer_read %subview_0[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x1xf32, strided<[32, 1], offset: ?>>, vector<4x1xf32>
+            %5 = vector.transfer_read %subview_1[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x96xf32, strided<[96, 1], offset: ?>>, vector<1x96xf32>
+            %6 = vector.contract {indexing_maps = [affine_map<(d1, d2, d3) -> (d1, d3)>, affine_map<(d1, d2, d3) -> (d3, d2)>, affine_map<(d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<4x1xf32>, vector<1x96xf32> into vector<4x96xf32>
+            scf.yield %6 : vector<4x96xf32>
+          }
+
+        vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>>
+      }
+    }
+    return %arg2 : memref<4x96xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @fp32_batch_matmul_vector_size_8(
+// CHECK-COUNT-48: vector.fma{{.*}}vector<8xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From f34f50d987593f182dcd6e772000855549f5607e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 27 Oct 2025 03:34:41 -0700
Subject: [PATCH 6/6] added more comments + more test-cases

---
 .../mlir/Dialect/X86Vector/Transforms.h       |   5 +-
 .../X86Vector/Transforms/NanoKernels.cpp      | 187 +++++++++++++-----
 .../vector-contract-to-nanokernels.mlir       | 126 +++++++++++-
 3 files changed, 262 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index a9487adba002a..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -84,7 +84,10 @@ struct MaskHelper {
 };
 
 //===----------------------------------------------------------------------===//
-// Nano-kernels
+// Transforms a scheduled pattern to lower a tiled batch or batch-reduce
+// vector contraction into a sequence of nanokernels.
+// The transformation is tailored to the target machine architecture
+// and guided by the user-specified vector size.
 void populateVectorContractNanokernelLoweringPatterns(
     RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index c8270c96daeea..4d0906a2ec057 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements matmul rewrites as nanokernels with respect to target
-// machine for FP32 and BF16 (TODO) types.
+// machine for FP32 (for selective batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
 //
 //===----------------------------------------------------------------------===//
 
@@ -31,12 +32,18 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
 static FailureOr<SmallVector<scf::ForOp>>
-getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
   SmallVector<scf::ForOp> list;
   Operation *current = contractOp;
-  // It is register tiled loop structure on batch reduce matmul
-  // (M->N->Batch-reduce->K).
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+  // It is register tiled loop structure on batch (or reduce) matmul
+  // (M->N->(reduce)->K).
   for (unsigned int i = 0; i < dimCount; i++) {
     Operation *parent = current->getParentOfType<scf::ForOp>();
     if (!parent)
@@ -47,14 +54,14 @@ getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
   return list;
 }
 
-static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
-                                     SmallVector<memref::SubViewOp> subviews,
-                                     unsigned int dimCount) {
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
   auto subviewOpLhsOffsets = subviews[0].getOffsets();
   auto subviewOpRhsOffsets = subviews[1].getOffsets();
   auto subviewOpAccOffsets = subviews[2].getOffsets();
 
-  if (dimCount == 4) {
+  if (matmulType == MatMulType::BatchReduce) {
     Value ivK = loops[0].getInductionVar();
     if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
       return failure();
@@ -73,7 +80,7 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
       return failure();
   }
 
-  if (dimCount == 3) {
+  if (matmulType == MatMulType::Batch) {
     Value ivK = loops[0].getInductionVar();
     if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
       return failure();
@@ -96,13 +103,16 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
                           unsigned int vectorSize, Value subviewOpAcc) {
 
   SmallVector<Value> accumulators;
+
+  // Initialize local variable on assumption that M tile is larger than N
   unsigned int outerBound = M;
   unsigned int innerBound = N;
 
   unsigned int outerStep = 1;
   unsigned int innerStep = vectorSize;
 
-  if ((N / vectorSize) > M) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
     outerBound = N;
     innerBound = M;
 
@@ -115,7 +125,7 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
       Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
       Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
 
-      if ((N / vectorSize) > M) {
+      if (isNTileLarge) {
         indexOp_A = indexOp_B;
         indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
       }
@@ -130,17 +140,18 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
   return accumulators;
 }
 
-// Function accepts A Matrix, B Matrix, C Matrix (as vectors) and generate
-// equivalent target specific nanokernels. Returns the final accumulator as
-// output. Based on M tile, N tile, and vector size it generated optimized
-// nanokernels with condition of reduction and K dimension of the input matrix
-// are 1.
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels, assuming both the reduction and K dimensions of the input
+// matrices are 1.
 //
 // Input: Matrix A, Matrix B, Accmulator as M*(N/vector size) vectors, M tile
 // size, N tile size, Vector size.
 //
 // Output:
-// case i: M > (N/vector size). For example, M=3; N=32; vector size = 16.
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
 //  load_B0 = load B[0-15] into vector<16xf32>
 //  load_B1 = load B[16-31] into vector<16xf32>
 //  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
@@ -153,7 +164,7 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
 //  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
 //  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
 //
-// case ii: M <= (N/vector size). For example, M=2; N=48; vector size = 16.
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
 //  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
 //  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
 //  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
@@ -172,7 +183,7 @@ SmallVector<Value>
 generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
                     unsigned int vectorSize, unsigned int vnni, unsigned int M,
                     unsigned int N, ValueRange acc, Value matA, Value matB,
-                    unsigned int dimCount) {
+                    MatMulType matmulType) {
 
   SmallVector<Value> accumulators;
   SmallVector<Value> matLoad;
@@ -195,7 +206,8 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
   unsigned int fmaBound = M;
 
   // update helper variables if N tile size is smaller
-  if ((N / vectorSize) < M) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
     outerBound = N;
     innerBound = M;
 
@@ -216,14 +228,14 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
     Value valueRow;
 
-    if ((N / vectorSize) > M) {
+    if (isNTileLarge) {
 
       // With the assumption as batch-reduce matmul initialize reduction, M, and
       // K dimension.
       SmallVector<Value> index = {c0, indexOp_i, c0};
 
       // Remove reduction dimension if it is a batch matmul
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -241,7 +253,7 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
       SmallVector<Value> index = {c0, c0, indexOp_i};
 
       // Remove reduction dimension if it is a batch matmul
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -259,9 +271,9 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
     Value valueRow;
 
-    if ((N / vectorSize) < M) {
+    if (!isNTileLarge) {
       SmallVector<Value> index = {c0, indexOp_j, c0};
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -275,7 +287,7 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     } else {
 
       SmallVector<Value> index = {c0, c0, indexOp_j};
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -305,7 +317,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
     vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
     Type elementType, unsigned int vectorSize, unsigned int vnni,
     unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
-    unsigned int dimCount) {
+    MatMulType matmulType) {
   auto newKForOp = scf::ForOp::create(
       rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
@@ -330,7 +342,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
         auto evenFMAs = generateNanokernels(
             rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
             iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
-            dimCount);
+            matmulType);
 
         scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
@@ -345,7 +357,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
     vector::TransferReadOp vectorReadOpLhs,
     vector::TransferReadOp vectorReadOpRhs, Type elementType,
     unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
-    ValueRange iterArgsNewReductionForOp, unsigned int dimCount) {
+    ValueRange iterArgsNewReductionForOp, MatMulType matmulType) {
 
   auto newKForOp = scf::ForOp::create(
       rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
@@ -367,7 +379,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
         auto evenFMAs =
             generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
                                 N, iterArgsNewKForOp, lhsClone->getResult(0),
-                                rhsClone->getResult(0), dimCount);
+                                rhsClone->getResult(0), matmulType);
 
         scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
@@ -378,14 +390,15 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
 Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
                                      VectorType vecType,
                                      SmallVector<Value> FMAs, Value accVec,
-                                     unsigned int vecSize, unsigned int M,
+                                     unsigned int vectorSize, unsigned int M,
                                      unsigned int N) {
 
   auto strides = rewriter.getI64ArrayAttr({1});
-  if ((N / vecSize) > M) {
-    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    for (unsigned int j = 0, k = 0; j < (N / vectorSize); j++) {
       for (unsigned int i = 0; i < M; i++) {
-        unsigned int off = (j * vecSize) + (i * N);
+        unsigned int off = (j * vectorSize) + (i * N);
         auto offsets = rewriter.getI64ArrayAttr({off});
         accVec = vector::InsertStridedSliceOp::create(
             rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -394,7 +407,7 @@ Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
     }
 
   } else {
-    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vectorSize) {
       auto offsets = rewriter.getI64ArrayAttr({i});
       accVec = vector::InsertStridedSliceOp::create(
           rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -404,6 +417,53 @@ Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
   return accVec;
 }
 
+// Rewrite pattern for vector.contract operations.
+// Input: vector.contract with tiled dimensions (batch or batch-reduce matmul).
+// Matching pattern:
+//   scf.for (0 to M) step m_tile {
+//     scf.for (0 to N) step n_tile {
+//       - Subview of accumulator matrix, e.g. acc : memref<m_tilexn_tilexf32>
+//       - %read = vector.transfer_read memref<m_tilexn_tilexf32> to
+//           vector<m_tilexn_tilexf32>
+//       %1 = scf.for (0 to reduce) iter_args_reduce=%read step reduce_tile {
+//         %2 = scf.for (0 to K) iter_args_k=%iter_args_reduce step k_tile {
+//           - Subviews of the A and B matrices
+//           - vector.transfer_read of A and B
+//           - %acc = vector.contract %read_A, %read_B, %iter_args_k
+//           scf.yield %acc
+//         }
+//         scf.yield %2
+//       }
+//       vector.transfer_write %1 into the accumulator matrix
+//     }
+//   }
+//
+//
+// Rewritten IR:
+//   scf.for (0 to M) step m_tile {
+//     scf.for (0 to N) step n_tile {
+//       - Subview of accumulator matrix, e.g. acc : memref<m_tilexn_tilexf32>
+//       - Let a = (n_tile / vector_size) * m_tile (number of acc vectors).
+//       - Load the accumulator matrix as `a` vectors (vector_size = 16 here):
+//       - %v0 = load acc[0][0-15] into vector<16xf32>
+//       - %v1 = load acc[0][16-31] into vector<16xf32>
+//       - %v2 = load acc[1][0-15] into vector<16xf32>
+//       .
+//       .
+//       .
+//       - %v(a-1) = load acc[m_tile-1][n_tile-16..n_tile-1] into vector<16xf32>
+//       %r = scf.for (0 to reduce) iter_args=(%v0..%v(a-1)) step reduce_tile {
+//         %k = scf.for (0 to K) iter_args_k=%iter_args_reduce step k_tile {
+//           - emit nanokernels (see the comment above generateNanokernels),
+//             producing the updated accumulators %acc[0] .. %acc[a-1]
+//           scf.yield %acc[0] .. %acc[a-1]
+//         }
+//         scf.yield %k[0] .. %k[a-1]
+//       }
+//       %5 = %r[0] .. %r[a-1] merged into vector<m_tilexn_tilexf32>
+//       vector.transfer_write %5 into the accumulator matrix
+//     }
+//   }
 struct VectorContractNanokernelLowering
     : public OpRewritePattern<vector::ContractionOp> {
   VectorContractNanokernelLowering(MLIRContext *context,
@@ -417,7 +477,6 @@ struct VectorContractNanokernelLowering
     auto loc = contractOp.getLoc();
 
     unsigned int vectorSize = 8;
-
     if (userVectorSize)
       vectorSize = *userVectorSize;
 
@@ -426,14 +485,28 @@ struct VectorContractNanokernelLowering
                                          "Expects add combining kind");
     }
 
-    auto dimCount = contractOp.getRhsType().getRank() + 1;
+    SmallVector<vector::IteratorType> contractIteratorTypes =
+        contractOp.getIteratorTypesArray();
+
+    unsigned int reductionCount =
+        std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
+                   vector::IteratorType::reduction);
 
-    if ((dimCount != 3) && (dimCount != 4))
+    MatMulType matmulType = MatMulType::Others;
+
+    if (reductionCount == 1)
+      matmulType = MatMulType::Batch;
+
+    if (reductionCount == 2)
+      matmulType = MatMulType::BatchReduce;
+
+    if ((matmulType != MatMulType::BatchReduce) &&
+        (matmulType != MatMulType::Batch))
       return rewriter.notifyMatchFailure(
           contractOp, "Expects batch-reduce or batch matmuls");
 
     // Get the M, N, K, and batch-reduce loops
-    auto loops = getNestedLoop(contractOp, dimCount);
+    auto loops = getTiledMatmulLoopNest(contractOp, matmulType);
     if (failed(loops))
       return rewriter.notifyMatchFailure(
           contractOp, "Invalid loop nest in contract pattern");
@@ -442,15 +515,21 @@ struct VectorContractNanokernelLowering
     scf::ForOp kForOp = nestedLoops[0];
     scf::ForOp reductionForOp;
 
+    if (contractOp.getAcc().getDefiningOp<vector::TransferReadOp>()) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "The Accumulator matrix should be hoisted outside the K "
+                      "or reduction loop");
+    }
+
     vector::TransferReadOp vectorReadOpAcc;
 
-    if (dimCount == 4) {
+    if (matmulType == MatMulType::BatchReduce) {
       reductionForOp = nestedLoops[1];
       vectorReadOpAcc = reductionForOp.getInitArgs()[0]
                             .getDefiningOp<vector::TransferReadOp>();
     }
 
-    if (dimCount == 3) {
+    if (matmulType == MatMulType::Batch) {
       vectorReadOpAcc =
           kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
     }
@@ -480,10 +559,11 @@ struct VectorContractNanokernelLowering
 
     // The M, N, K, and batch-reduce loop iv should match the iv's
     // used in the subviews
-    auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
+    auto checkLoops =
+        checkMatmulLoopAndSubviewOffsetsMatching(*loops, subviews, matmulType);
     if (failed(checkLoops))
       return rewriter.notifyMatchFailure(
-          contractOp, "Loops doesn't match the iv in subviews");
+          contractOp, "tiled loops doesn't match the iv in subviews");
 
     auto elementType =
         (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
@@ -505,14 +585,15 @@ struct VectorContractNanokernelLowering
     if (K != 1)
       return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
 
-    if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+    if (matmulType == MatMulType::BatchReduce &&
+        lhsType.getDimSize(lhsType.getRank() - 3) != 1)
       return rewriter.notifyMatchFailure(contractOp,
                                          "The reduction-dim should be 1");
 
-    if (dimCount == 4)
+    if (matmulType == MatMulType::BatchReduce)
       rewriter.setInsertionPoint(reductionForOp);
 
-    if (dimCount == 3)
+    if (matmulType == MatMulType::Batch)
       rewriter.setInsertionPoint(kForOp);
 
     // Load  MxN C sub matrix into acc vectors (e.g, <vectorSizexf32>)
@@ -522,7 +603,7 @@ struct VectorContractNanokernelLowering
     // Create the batch-reduce and K-loop with acc vectors as the loop
     // iterargs (batch-reduce matmul) + nanokernel generation
     scf::ForOp newLoop;
-    if (dimCount == 4) {
+    if (matmulType == MatMulType::BatchReduce) {
       newLoop = scf::ForOp::create(
           rewriter, reductionForOp.getLoc(), reductionForOp.getLowerBound(),
           reductionForOp.getUpperBound(), reductionForOp.getStep(),
@@ -533,7 +614,7 @@ struct VectorContractNanokernelLowering
             scf::ForOp newKForOp = createGEMMLoopsWithAccAsIterArgs(
                 rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
                 ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
-                iterArgsNewReductionForOp, dimCount);
+                iterArgsNewReductionForOp, matmulType);
 
             scf::YieldOp::create(rewriterNewReductionForOp,
                                  locNewReductionForOp, newKForOp.getResults());
@@ -541,13 +622,13 @@ struct VectorContractNanokernelLowering
     }
 
     // Create only the K-loop (batch matmul) + nanokernel generation
-    if (dimCount == 3) {
+    if (matmulType == MatMulType::Batch) {
       newLoop = createGEMMLoopsWithAccAsIterArgs(
           rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, elementType,
-          vectorSize, vnni, M, N, accumulators, dimCount);
+          vectorSize, vnni, M, N, accumulators, matmulType);
     }
 
-    // Combine all acc vectors into a MxN C matrix
+    // Combine all output accumulator vectors into a m_tilexn_tile C matrix
     auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
     auto zeroAttr =
         DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
@@ -560,10 +641,10 @@ struct VectorContractNanokernelLowering
     auto reshapeAcc = vector::ShapeCastOp::create(rewriter, loc, accTy, accVec);
 
     // Replace all the use of vector.contract with results of nanokernels
-    if (dimCount == 4)
+    if (matmulType == MatMulType::BatchReduce)
       rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
 
-    if (dimCount == 3)
+    if (matmulType == MatMulType::Batch)
       rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
 
     return success();
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
index 184ba346e8638..32620657bd52d 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -4,7 +4,7 @@
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 module {
-  func.func @fp32_vectorSize_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+  func.func @fp32_batch_reduce_matmul_vector_size_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
     %0 = ub.poison : f32
     %c0 = arith.constant 0 : index
     %c4 = arith.constant 4 : index
@@ -33,7 +33,7 @@ module {
   }
 }
 
-// CHECK-LABEL: func.func @fp32_vectorSize_16(
+// CHECK-LABEL: func.func @fp32_batch_reduce_matmul_vector_size_16(
 // CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32>
 // CHECK-NOT: vector.contract
 
@@ -91,3 +91,125 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+// Negative test: the contraction has no enclosing scf.for loop nest, so it must not be rewritten.
+module {
+  func.func @not_tiled_no_rewriting(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %c0 = arith.constant 0 : index
+    %0 = ub.poison : f32
+    %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x32xf32>, vector<1x4x32xf32>
+    %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x96xf32>, vector<1x32x96xf32>
+    %3 = vector.transfer_read %arg2[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32>, vector<4x96xf32>
+    %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x32xf32>, vector<1x32x96xf32> into vector<4x96xf32>
+    vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32>
+    return %arg2 : memref<4x96xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @not_tiled_no_rewriting(
+// CHECK-NOT: vector.fma{{.*}}vector<8xf32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Negative test: operands are tensors (tensor.extract_slice, not memref.subview), so no rewrite is expected.
+module {
+  func.func @tensor_type_no_rewriting(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
+    %0 = ub.poison : f32
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c4 = arith.constant 4 : index
+    %c16 = arith.constant 16 : index
+    %c1 = arith.constant 1 : index
+    %1 = scf.for %arg3 = %c0 to %c32 step %c4 iter_args(%arg4 = %arg2) -> (tensor<32x32xf32>) {
+      %2 = scf.for %arg5 = %c0 to %c32 step %c16 iter_args(%arg6 = %arg4) -> (tensor<32x32xf32>) {
+        %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (tensor<32x32xf32>) {
+          %4 = scf.for %arg9 = %c0 to %c32 step %c1 iter_args(%arg10 = %arg8) -> (tensor<32x32xf32>) {
+            %extracted_slice = tensor.extract_slice %arg0[%arg7, %arg3, %arg9] [1, 4, 1] [1, 1, 1] : tensor<32x32x32xf32> to tensor<1x4x1xf32>
+            %extracted_slice_0 = tensor.extract_slice %arg1[%arg7, %arg9, %arg5] [1, 1, 16] [1, 1, 1] : tensor<32x32x32xf32> to tensor<1x1x16xf32>
+            %extracted_slice_1 = tensor.extract_slice %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<32x32xf32> to tensor<4x16xf32>
+            %5 = vector.transfer_read %extracted_slice[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x4x1xf32>, vector<1x4x1xf32>
+            %6 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x1x16xf32>, vector<1x1x16xf32>
+            %7 = vector.transfer_read %extracted_slice_1[%c0, %c0], %0 {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
+            %8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x4x1xf32>, vector<1x1x16xf32> into vector<4x16xf32>
+            %9 = vector.transfer_write %8, %extracted_slice_1[%c0, %c0] {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32>
+            %inserted_slice = tensor.insert_slice %9 into %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<4x16xf32> into tensor<32x32xf32>
+            scf.yield %inserted_slice : tensor<32x32xf32>
+          }
+          scf.yield %4 : tensor<32x32xf32>
+        }
+        scf.yield %3 : tensor<32x32xf32>
+      }
+      scf.yield %2 : tensor<32x32xf32>
+    }
+    return %1 : tensor<32x32xf32>
+  }
+}
+// NOTE(review): the accumulator read %7 is also not hoisted, which alone may block the rewrite; consider hoisting it to isolate the tensor-type path.
+// CHECK-LABEL: func.func @tensor_type_no_rewriting(
+// CHECK-NOT: vector.fma{{.*}}vector<16xf32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Negative test: the accumulator read %3 is inside the K loop (not hoisted), so no rewrite is expected.
+module {
+  func.func @accumulator_not_hoisted_outside_K_or_reduction_loop(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %0 = ub.poison : f32
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c96 = arith.constant 96 : index
+    %c32 = arith.constant 32 : index
+    %c1 = arith.constant 1 : index
+    scf.for %arg3 = %c0 to %c4 step %c4 {
+      scf.for %arg4 = %c0 to %c96 step %c32 {
+        %subview = memref.subview %arg2[%arg3, %arg4] [4, 32] [1, 1] : memref<4x96xf32> to memref<4x32xf32, strided<[96, 1], offset: ?>>
+        scf.for %arg5 = %c0 to %c1 step %c1 {
+          scf.for %arg6 = %c0 to %c32 step %c1 {
+            %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>
+            %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4] [1, 1, 32] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>>
+            %1 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32>
+            %2 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x32xf32>
+            %3 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x32xf32, strided<[96, 1], offset: ?>>, vector<4x32xf32>
+            %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x32xf32> into vector<4x32xf32>
+            vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x32xf32>, memref<4x32xf32, strided<[96, 1], offset: ?>>
+          }
+        }
+      }
+    }
+    return %arg2 : memref<4x96xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @accumulator_not_hoisted_outside_K_or_reduction_loop(
+// CHECK-NOT: vector.fma{{.*}}vector<16xf32>
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list