[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

Arun Thangamani llvmlistbot at llvm.org
Tue Dec 16 03:24:55 PST 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/170267

>From fe09e05837b2ee469da250562e45c4bfa15865ff Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 2 Dec 2025 01:14:35 -0800
Subject: [PATCH 1/8] bf16 vector.contract lowering to vector.fma using AVX2
 BF16 packed ops.

---
 .../TransformOps/X86VectorTransformOps.td     |  10 +
 .../mlir/Dialect/X86Vector/Transforms.h       |   2 +
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  |  43 ++++
 mlir/lib/Dialect/X86Vector/CMakeLists.txt     |   1 +
 .../TransformOps/X86VectorTransformOps.cpp    |   5 +
 .../X86Vector/Transforms/CMakeLists.txt       |   4 +-
 .../Transforms/VectorContractBF16ToFMA.cpp    | 225 +++++++++++++++++
 .../VectorContractToPackedTypeDotProduct.cpp  |  87 +------
 .../Dialect/X86Vector/Utils/CMakeLists.txt    |  13 +
 .../X86Vector/Utils/X86VectorUtils.cpp        | 110 +++++++++
 .../vector-contract-bf16-to-fma.mlir          | 229 ++++++++++++++++++
 11 files changed, 642 insertions(+), 87 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
 create mode 100644 mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
 create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..b4f1c0e12ef9b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,16 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_bf16_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns lowering BF16 vector.contract ops in VNNI layout to FMAs.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..aba903845e429 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,8 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
+void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
new file mode 100644
index 0000000000000..8a76009ddb907
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -0,0 +1,51 @@
+//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <cstdint>
+#include <optional>
+#include <string>
+
+namespace mlir {
+class Type;
+class ShapedType;
+class OpOperand;
+class AffineDimExpr;
+class AffineMap;
+class VectorType;
+class Operation;
+
+namespace x86vector {
+/// Operand ranks associated with VNNI layouts.
+/// NOTE(review): presumably the rank of a VNNI-packed operand for transpose,
+/// GEMM and batch-reduce GEMM (inputs/outputs); confirm. TRANSPOSE, GEMM and
+/// BRGEMM_OUTS share the value 3, so those enumerators compare equal.
+enum class VnniOperandRank {
+  TRANSPOSE = 3,
+  GEMM = 3,
+  BRGEMM_INS = 4,
+  BRGEMM_OUTS = 3
+};
+
+/// Returns true if `op` is a contraction whose first two operands (A and B)
+/// are in VNNI layout according to `indexingMaps` (A, B and C maps in order):
+///   - A: [...][K/vnniFactor][vnniFactor]
+///   - B: [...][K/vnniFactor][N][vnniFactor]
+/// The VNNI factor is the innermost dimension of A and B; it must be static,
+/// non-zero and even and, if `blockingFactor` is given, equal to it.
+bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
+                    std::optional<unsigned> blockingFactor = std::nullopt);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index cb1e9d01821a2..329a6c3e80254 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..172f159b43f80 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
   x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
 }
 
+void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  ::mlir::x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..9eb94691753cf 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,14 +3,16 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   LegalizeForLLVMExport.cpp
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
+  VectorContractBF16ToFMA.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
-  MLIRX86VectorDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRVectorDialect
   MLIRVectorUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorUtils
   )
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
new file mode 100644
index 0000000000000..58cc7f7497044
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -0,0 +1,251 @@
+//===- VectorContractBF16ToFMA.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lowers a BF16 vector.contract in VNNI layout (blocking factor 2) with an F32
+// accumulator to x86vector BF16 packed ops combined through vector.fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+/// The memref a vector operand was read from, together with the offsets (one
+/// per memref dimension) at which that read started.
+struct MemRefReadSource {
+  Value memref;
+  SmallVector<OpFoldResult> offsets;
+};
+
+/// Returns the memref and offsets that `vec` was read from, without creating
+/// any IR. Accepts `vector.load` and unmasked, in-bounds `vector.transfer_read`
+/// with a minor-identity permutation map, reading from a memref with at least
+/// `minRank` dimensions. Fails otherwise, e.g. for a transfer_read from a
+/// tensor, which `memref.subview` cannot slice.
+static FailureOr<MemRefReadSource> getMemRefReadSource(Value vec,
+                                                       int64_t minRank) {
+  MemRefReadSource src;
+  if (auto transferRead = vec.getDefiningOp<vector::TransferReadOp>()) {
+    // The subviews built later only re-read the same elements if the vector
+    // dims map in order onto the trailing memref dims without mask/padding.
+    if (transferRead.getMask() || transferRead.hasOutOfBoundsDim() ||
+        !transferRead.getPermutationMap().isMinorIdentity())
+      return failure();
+    src.memref = transferRead.getOperand(0);
+    src.offsets.assign(transferRead.getIndices().begin(),
+                       transferRead.getIndices().end());
+  } else if (auto load = vec.getDefiningOp<vector::LoadOp>()) {
+    src.memref = load.getOperand(0);
+    src.offsets.assign(load.getIndices().begin(), load.getIndices().end());
+  } else {
+    return failure();
+  }
+
+  if (!isa<MemRefType>(src.memref.getType()) ||
+      static_cast<int64_t>(src.offsets.size()) < minRank)
+    return failure();
+  return src;
+}
+
+/// Creates the memref.subview ops read by the packed BF16 ops.
+///
+/// `mnDim` is the extent of the M/N dimension to slice, `vnniDim` the extent
+/// of the innermost (VNNI) dimension, and `mnDimIndx` the position of the M/N
+/// dimension counted from the innermost one (1 = innermost). `src` must come
+/// from getMemRefReadSource with `minRank >= mnDimIndx`.
+///
+/// For `mnDim != 1` one subview spanning the operand is returned. For
+/// `mnDim == 1` two single-element subviews are returned: [0] at VNNI offset 1
+/// (odd element) and [1] at VNNI offset 0 (even element).
+/// NOTE(review): the VNNI offset of the original read is replaced by the
+/// constants 1/0, which assumes that read started at VNNI offset 0.
+static SmallVector<memref::SubViewOp>
+createSubviews(Location loc, PatternRewriter &rewriter,
+               const MemRefReadSource &src, int64_t mnDim, int64_t vnniDim,
+               int64_t mnDimIndx) {
+  SmallVector<OpFoldResult> offsets = src.offsets;
+  const int64_t rank = static_cast<int64_t>(offsets.size());
+  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> sizes(rank, rewriter.getIndexAttr(1));
+  sizes[rank - 1] = rewriter.getIndexAttr(vnniDim);
+  sizes[rank - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+
+  if (mnDim == 1)
+    offsets[rank - 1] = rewriter.getIndexAttr(1);
+
+  SmallVector<memref::SubViewOp> subviews;
+  subviews.push_back(memref::SubViewOp::create(rewriter, loc, src.memref,
+                                               offsets, sizes, strides));
+
+  if (mnDim == 1) {
+    offsets[rank - 1] = rewriter.getIndexAttr(0);
+    sizes[rank - 1] = rewriter.getIndexAttr(1);
+    subviews.push_back(memref::SubViewOp::create(rewriter, loc, src.memref,
+                                                 offsets, sizes, strides));
+  }
+  return subviews;
+}
+
+/// Lowers a BF16 `vector.contract` in VNNI layout (blocking factor 2) with an
+/// F32 accumulator, where exactly one of LHS/RHS has a non-unit M/N dimension
+/// of size 4 or 8 and the other contributes a single, broadcast VNNI pair:
+///   fma(bcst(even), cvt_even(nonUnit), fma(bcst(odd), cvt_odd(nonUnit), acc))
+/// The packed BF16 ops read through memref.subview, so both operands must be
+/// produced by vector.load / vector.transfer_read on a memref.
+struct VectorContractBF16ToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    if (!isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), 2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format.");
+
+    // Non-unit dims of a shape; for LHS/RHS the last one is the VNNI dim.
+    auto getNonUnitDims = [](ArrayRef<int64_t> shape) {
+      SmallVector<int64_t> dims;
+      llvm::copy_if(shape, std::back_inserter(dims),
+                    [](int64_t dim) { return dim != 1; });
+      return dims;
+    };
+    SmallVector<int64_t> nonUnitDimLhs = getNonUnitDims(lhsTy.getShape());
+    SmallVector<int64_t> nonUnitDimRhs =
+        getNonUnitDims(contractOp.getRhsType().getShape());
+
+    // Compare sizes directly: `size() - 1` wraps around for an empty vector.
+    if (nonUnitDimLhs.size() > 1 && nonUnitDimRhs.size() > 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape other than VNNI.");
+
+    if (nonUnitDimLhs.size() != 2 && nonUnitDimRhs.size() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    SmallVector<int64_t> nonUnitDimAcc = getNonUnitDims(accTy.getShape());
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    // The non-unit M/N dim must be 4 or 8 (the F32 result widths used for the
+    // packed BF16 loads) and must also equal the accumulator's non-unit dim.
+    const bool lhsIsNonUnit = nonUnitDimLhs.size() == 2;
+    const int64_t nonUnitDim =
+        lhsIsNonUnit ? nonUnitDimLhs.front() : nonUnitDimRhs.front();
+    if ((nonUnitDim != 4 && nonUnitDim != 8) ||
+        nonUnitDimAcc.front() != nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 packed load operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8.");
+
+    // Resolve both memref sources before creating any IR, so that a failed
+    // match leaves the IR untouched. The LHS M dim is taken to be third from
+    // the innermost dim and the RHS N dim second.
+    Value lhs = contractOp.getLhs();
+    Value rhs = contractOp.getRhs();
+    Value unitOperand = lhsIsNonUnit ? rhs : lhs;
+    Value nonUnitOperand = lhsIsNonUnit ? lhs : rhs;
+    const int64_t unitMNDimIndx = 2;
+    const int64_t nonUnitMNDimIndx = lhsIsNonUnit ? 3 : 2;
+    FailureOr<MemRefReadSource> unitSrc =
+        getMemRefReadSource(unitOperand, unitMNDimIndx);
+    FailureOr<MemRefReadSource> nonUnitSrc =
+        getMemRefReadSource(nonUnitOperand, nonUnitMNDimIndx);
+    if (failed(unitSrc) || failed(nonUnitSrc))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects operands read from a memref by vector.load or "
+                      "an in-bounds, unmasked vector.transfer_read.");
+
+    Location loc = contractOp.getLoc();
+    VectorType dstType = VectorType::get(nonUnitDim, rewriter.getF32Type());
+    auto castAcc = vector::ShapeCastOp::create(rewriter, loc, dstType,
+                                               contractOp.getAcc());
+
+    SmallVector<memref::SubViewOp> unitDimSubview =
+        createSubviews(loc, rewriter, *unitSrc, /*mnDim=*/1, /*vnniDim=*/1,
+                       unitMNDimIndx);
+    SmallVector<memref::SubViewOp> nonUnitDimSubview =
+        createSubviews(loc, rewriter, *nonUnitSrc, nonUnitDim, /*vnniDim=*/2,
+                       nonUnitMNDimIndx);
+
+    // Odd VNNI elements (offset 1), accumulated onto the input accumulator.
+    auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, unitDimSubview[0]);
+    auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+        rewriter, loc, dstType, nonUnitDimSubview[0]);
+    auto oddIndxFMA =
+        vector::FMAOp::create(rewriter, loc, loadBcstOddIndxElementToF32,
+                              loadOddIndxElementF32, castAcc);
+
+    // With a single user, emit the even half right before that user rather
+    // than next to the contraction.
+    SmallVector<Operation *> users;
+    for (OpResult result : contractOp->getResults())
+      for (Operation *user : result.getUsers())
+        users.push_back(user);
+    if (users.size() == 1)
+      rewriter.setInsertionPoint(users[0]);
+
+    // Even VNNI elements (offset 0), accumulated onto the odd-half result.
+    auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, unitDimSubview[1]);
+    auto loadEvenIndxElementF32 =
+        x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                       nonUnitDimSubview[0]);
+    vector::FMAOp fma =
+        vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
+                              loadEvenIndxElementF32, oddIndxFMA);
+
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+    rewriter.replaceOp(contractOp, castFma);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractBF16ToFMAPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 1e64811db910b..a00a3e5bdd766 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 
 #include "mlir/IR/BuiltinAttributes.h"
@@ -26,92 +27,6 @@ using namespace mlir::x86vector;
 
 namespace {
 
-static FailureOr<SmallVector<mlir::utils::IteratorType>>
-inferIteratorsFromOutMap(AffineMap map) {
-  if (!map.isProjectedPermutation())
-    return failure();
-  SmallVector<mlir::utils::IteratorType> iterators(
-      map.getNumDims(), mlir::utils::IteratorType::reduction);
-  for (auto expr : map.getResults())
-    if (auto dim = dyn_cast<AffineDimExpr>(expr))
-      iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
-  return iterators;
-}
-
-// Returns true if the operation is in VNNI layout.
-// Optionally, the check can be constrained to a specific VNNI blocking factor.
-static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
-                           std::optional<unsigned> blockingFactor) {
-  // Narrow down type operations - VNNI only applies to contractions.
-  FailureOr<linalg::ContractionDimensions> dims =
-      linalg::inferContractionDims(indexingMaps);
-  if (failed(dims))
-    return false;
-
-  auto matA = op->getOperand(0);
-  auto matB = op->getOperand(1);
-  auto typeA = dyn_cast<ShapedType>(matA.getType());
-  auto typeB = dyn_cast<ShapedType>(matB.getType());
-  unsigned rankA = typeA.getRank();
-  unsigned rankB = typeB.getRank();
-  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
-  if (rankA < 3 || rankB < 3)
-    return false;
-
-  // At least two reduction dimensions are expected:
-  // one for the VNNI factor and one for the K dimension
-  if (dims->k.size() < 2)
-    return false;
-
-  // Validate affine maps - VNNI computation should be defined by the two
-  // innermost reduction iterators.
-  // The input matrix dimensions layout must match the following:
-  //   - matrix A - [...][K/vnniFactor][vnniFactor]
-  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
-  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
-  if (failed(maybeIters))
-    return false;
-  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
-  AffineMap mapA = indexingMaps[0];
-  AffineMap mapB = indexingMaps[1];
-
-  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
-  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
-  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
-      iteratorTypes[vnniDimA.getPosition()] !=
-          mlir::utils::IteratorType::reduction)
-    return false;
-  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
-  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
-  if (!redDimA || !redDimB || redDimA != redDimB ||
-      iteratorTypes[redDimA.getPosition()] !=
-          mlir::utils::IteratorType::reduction)
-    return false;
-  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
-  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
-                           mlir::utils::IteratorType::parallel)
-    return false;
-
-  // VNNI factor must be:
-  //   - the innermost inputs' dimension
-  //   - statically known
-  //   - multiple of 2 or equal to the specified factor
-  auto vnniDimSize = typeB.getShape().back();
-  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
-      vnniDimSize % 2 != 0)
-    return false;
-  if (typeA.getShape().back() != vnniDimSize)
-    return false;
-  if (blockingFactor && vnniDimSize != *blockingFactor)
-    return false;
-
-  // The split reduction dimension size should also match.
-  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
-    return false;
-
-  return true;
-}
-
 // Implements packed type outer product contraction as a sequence
 // of broadcast and packed dot-product operations.
 //
diff --git a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..6a2da861737ed
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRX86VectorUtils
+  X86VectorUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
new file mode 100644
index 0000000000000..e06307eeedcdb
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -0,0 +1,123 @@
+//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the X86Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+
+#define DEBUG_TYPE "x86vector-utils"
+
+using namespace mlir;
+
+// Infers iterator types from an output indexing map: dims used by the map's
+// results are parallel and every other dim is a reduction. Fails unless the
+// map is a projected permutation.
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  if (!map.isProjectedPermutation())
+    return failure();
+  SmallVector<mlir::utils::IteratorType> iterators(
+      map.getNumDims(), mlir::utils::IteratorType::reduction);
+  for (auto expr : map.getResults())
+    if (auto dim = dyn_cast<AffineDimExpr>(expr))
+      iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
+  return iterators;
+}
+
+// Returns true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+bool mlir::x86vector::isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+                           std::optional<unsigned> blockingFactor) {
+  // Narrow down type operations - VNNI only applies to contractions.
+  FailureOr<linalg::ContractionDimensions> dims =
+      linalg::inferContractionDims(indexingMaps);
+  if (failed(dims))
+    return false;
+
+  // The checks below read the A, B and C maps and the first two operands.
+  if (indexingMaps.size() < 3 || op->getNumOperands() < 2)
+    return false;
+
+  auto matA = op->getOperand(0);
+  auto matB = op->getOperand(1);
+  auto typeA = dyn_cast<ShapedType>(matA.getType());
+  auto typeB = dyn_cast<ShapedType>(matB.getType());
+  // getRank() is only valid on ranked shaped types.
+  if (!typeA || !typeB || !typeA.hasRank() || !typeB.hasRank())
+    return false;
+  unsigned rankA = typeA.getRank();
+  unsigned rankB = typeB.getRank();
+  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+  if (rankA < 3 || rankB < 3)
+    return false;
+
+  // At least two reduction dimensions are expected:
+  // one for the VNNI factor and one for the K dimension
+  if (dims->k.size() < 2)
+    return false;
+
+  // Validate affine maps - VNNI computation should be defined by the two
+  // innermost reduction iterators.
+  // The input matrix dimensions layout must match the following:
+  //   - matrix A - [...][K/vnniFactor][vnniFactor]
+  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
+  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+  if (failed(maybeIters))
+    return false;
+  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+  AffineMap mapA = indexingMaps[0];
+  AffineMap mapB = indexingMaps[1];
+  // Each map must produce one result per dimension of its operand.
+  if (mapA.getNumResults() != rankA || mapB.getNumResults() != rankB)
+    return false;
+
+  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+      iteratorTypes[vnniDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+  if (!redDimA || !redDimB || redDimA != redDimB ||
+      iteratorTypes[redDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+                           mlir::utils::IteratorType::parallel)
+    return false;
+
+  // VNNI factor must be:
+  //   - the innermost inputs' dimension
+  //   - statically known
+  //   - multiple of 2 or equal to the specified factor
+  auto vnniDimSize = typeB.getShape().back();
+  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+      vnniDimSize % 2 != 0)
+    return false;
+  if (typeA.getShape().back() != vnniDimSize)
+    return false;
+  if (blockingFactor && vnniDimSize != *blockingFactor)
+    return false;
+
+  // The split reduction dimension size should also match.
+  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+    return false;
+
+  return true;
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
new file mode 100644
index 0000000000000..0bee34f23a8a4
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_load(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x8x1x2xbf16>
+!vecB = vector<1x1x1x2xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<1x32x1x2xbf16>
+!memrefB = memref<1x1x4x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_load_bcst_B(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_load_bcst_B
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_fma_load(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_fma_load(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 36499af835134b1d69ca4b6aab73b7ef4d6fc3e4 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 2 Dec 2025 01:38:13 -0800
Subject: [PATCH 2/8] fix - clang format errors

---
 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index e06307eeedcdb..a934188bb058d 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -37,8 +37,9 @@ inferIteratorsFromOutMap(AffineMap map) {
 
 // Returns true if the operation is in VNNI layout.
 // Optionally, the check can be constrained to a specific VNNI blocking factor.
-bool mlir::x86vector::isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
-                           std::optional<unsigned> blockingFactor) {
+bool mlir::x86vector::isInVnniLayout(Operation *op,
+                                     ArrayRef<AffineMap> indexingMaps,
+                                     std::optional<unsigned> blockingFactor) {
   // Narrow down type operations - VNNI only applies to contractions.
   FailureOr<linalg::ContractionDimensions> dims =
       linalg::inferContractionDims(indexingMaps);

>From af67c0c4b6c0673b460ac5b056f369523f221638 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 8 Dec 2025 05:15:09 -0800
Subject: [PATCH 3/8] added new comments + negative test-cases

---
 .../TransformOps/X86VectorTransformOps.td     |  3 +-
 .../mlir/Dialect/X86Vector/Transforms.h       |  3 +
 .../Transforms/VectorContractBF16ToFMA.cpp    | 75 +++++++++++++++----
 .../vector-contract-bf16-to-fma.mlir          | 72 ++++++++++++++++++
 4 files changed, 139 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index b4f1c0e12ef9b..9c3ed1c8092a1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -42,7 +42,8 @@ def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
     "apply_patterns.x86vector.vector_contract_bf16_to_fma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that BF16 vector contract operation can be lowered to a FMA.
+    Collect patterns to lower a BF16 type vector.contract operation
+        to a FMA via emulation lowering using BF16 packed operations.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index aba903845e429..c4960ae28cb4f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,9 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
+// A set of patterns for lowering 32-bit packed BF16 vector contraction
+// operations to vector fused multiply-add (FMA) operations, following
+// the emulation-based approach using BF16 packed operations.
 void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 58cc7f7497044..43c3a38b277b7 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -19,11 +19,27 @@
 
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+// This function retrives the source operation of the load or transfer
+// reads and creates subviews for the BF16 packed-operations to
+// broadcast or load BF16 elements as F32 packed elements.
+//
+// For example:
+// ```
+//   vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+//   vector.load %arg0[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+//   memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
 static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
 getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
                           mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
@@ -52,6 +68,10 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
   if (!srcOperation)
     return failure();
 
+  Type srcType = srcOperation.getType();
+  if (!llvm::isa<mlir::MemRefType>(srcType))
+    return failure();
+
   llvm::SmallVector<OpFoldResult> strides;
   llvm::SmallVector<OpFoldResult> sizes;
 
@@ -83,6 +103,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
   return subviews;
 }
 
+// Implements outer product contraction as a sequence of BF16-packed
+// operation even/odd loads and FMA operations.
+//
+// For example:
+// ```
+//   %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
+//   %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
+//   return vector.contract %1, %2, %arg1
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+//   return vector.fma %4, %5, %3
+// ```
 struct VectorContractBF16ToFMA
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -94,6 +132,9 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(contractOp,
                                          "Expects add combining kind.");
 
+    // TODO: Move this validation to a comon utility folder. Planned to
+    // do once (code refactoring), all architecture specific nanokernel
+    // passes are merged into the repo.
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isBF16())
       return rewriter.notifyMatchFailure(contractOp,
@@ -141,8 +182,7 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(
           contractOp, "A or B should be a non-unit dim in acc.");
 
-    // Non-unit dimensions should match the vector length of BF16 or Int8
-    // dot-product.
+    // Non-unit dimensions should match the vector length of BF16.
     unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
                                                         : nonUnitDimRhs.front();
     if (nonUnitDim != 4 && nonUnitDim != 8 &&
@@ -151,25 +191,25 @@ struct VectorContractBF16ToFMA
           contractOp, "BF16 packed load operation expects non-unit (LHR or "
                       "RHS) dim and acc dim of size 4/8.");
 
+    // Lower vector.contract to FMAs with help of BF16 packed ops.
     auto loc = contractOp.getLoc();
-    auto castAcc = vector::ShapeCastOp::create(
-        rewriter, loc,
-        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
-        contractOp.getAcc());
-    mlir::VectorType dstType =
-        mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
-
     llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview;
     llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview;
 
+    // create the unit-dimension LHS or RHS subview and the
+    // corresponding non-unit dimension LHS or RHS subview on the other-side.
+    // For example, if LHS has type vector<1x1x2xbf16> and RHS has type
+    // vector<1x8x2xbf16>, we create two subview for the LHS and one subview
+    // for the RHS. In the opposite case (non-unit dimension on the LHS), we
+    // do vice-versa.
     if ((nonUnitDimRhs.size() - 1) > 0) {
-
       auto unitSubview = getSubviewFromVectorInput(
           loc, rewriter, contractOp.getLhs(), 1, 1, 2);
       auto nonUnitSubview = getSubviewFromVectorInput(
           loc, rewriter, contractOp.getRhs(), nonUnitDimRhs.front(), 2, 2);
       if (failed(unitSubview) || failed(nonUnitSubview))
-        return failure();
+        return rewriter.notifyMatchFailure(
+            contractOp, " The input source is not MemRef Type.");
 
       unitDimSubview = *unitSubview;
       nonUnitDimSubview = *nonUnitSubview;
@@ -180,12 +220,21 @@ struct VectorContractBF16ToFMA
       auto unitSubview = getSubviewFromVectorInput(
           loc, rewriter, contractOp.getRhs(), 1, 1, 2);
       if (failed(unitSubview) || failed(nonUnitSubview))
-        return failure();
+        return rewriter.notifyMatchFailure(
+            contractOp, " The input source is not MemRef Type.");
 
       unitDimSubview = *unitSubview;
       nonUnitDimSubview = *nonUnitSubview;
     }
 
+    auto castAcc = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+        contractOp.getAcc());
+    mlir::VectorType dstType =
+        mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+
+    // Load, broadcast, and do FMA for odd indexed BF16 elements.
     auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[0]);
     auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
@@ -203,12 +252,12 @@ struct VectorContractBF16ToFMA
       rewriter.setInsertionPoint(users[0]);
     }
 
+    // Load, broadcast, and do FMA for even indexed BF16 elements.
     auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[1]);
     auto loadEvenIndxElementF32 =
         x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
                                                        nonUnitDimSubview[0]);
-
     vector::FMAOp fma =
         vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
                               loadEvenIndxElementF32, oddIndxFMA);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index 0bee34f23a8a4..0b92e9367365f 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -227,3 +227,75 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @negative_tensor_type(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+  %0 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
+  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
+  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+  return %8 : vector<1x16xf32>
+}
+
+// CHECK-LABEL: @negative_tensor_type
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_no_memref_src(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_no_memref_src
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From a9ba397ac125342fadf45e58c0c80dd873b8b3bd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 9 Dec 2025 06:46:14 -0800
Subject: [PATCH 4/8] code refactor: new comments, code simplification etc.

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 98 +++++++++++--------
 .../X86Vector/Utils/X86VectorUtils.cpp        |  2 +-
 2 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 43c3a38b277b7..375364a4f1c7b 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -50,20 +50,17 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
   Value srcOperation;
   SmallVector<OpFoldResult> indexVals;
 
-  if (auto transferRead =
-          prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
-    srcOperation = transferRead.getOperand(0);
-    SmallVector<OpFoldResult> indexValues(transferRead.getIndices().begin(),
-                                          transferRead.getIndices().end());
-    indexVals = indexValues;
-  }
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return failure();
 
-  if (auto load = prodOp.getDefiningOp<mlir::vector::LoadOp>()) {
-    srcOperation = load.getOperand(0);
-    SmallVector<OpFoldResult> indexValues(load.getIndices().begin(),
-                                          load.getIndices().end());
-    indexVals = indexValues;
-  }
+  llvm::TypeSwitch<Operation *>(defOp)
+      .Case<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(
+          [&](auto readOp) {
+            srcOperation = readOp.getOperand(0);
+            indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                                  readOp.getIndices().end());
+          });
 
   if (!srcOperation)
     return failure();
@@ -75,14 +72,20 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
   llvm::SmallVector<OpFoldResult> strides;
   llvm::SmallVector<OpFoldResult> sizes;
 
-  for (unsigned int i = 0; i < indexVals.size(); i++) {
+  auto nonVNNIDimSize = indexVals.size() - 1;
+  // Create the size and stride offsets.
+  for (unsigned int i = 0; i < nonVNNIDimSize; i++) {
     strides.push_back(rewriter.getIndexAttr(1));
     sizes.push_back(rewriter.getIndexAttr(1));
   }
 
-  sizes[indexVals.size() - 1] = rewriter.getIndexAttr(vnniDim);
+  strides.push_back(rewriter.getIndexAttr(1));
+  sizes.push_back(rewriter.getIndexAttr(vnniDim));
+
+  // update the unit/nonUnit Dim size eiither it is A(LHS) or B(RHS).
   sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
 
+  // for unitDim, first broadcast odd element, so index is set to C1.
   if (mnDim == 1) {
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
   }
@@ -91,6 +94,11 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
                                            indexVals, sizes, strides);
   subviews.push_back(subview);
 
+  // For unit-dims, two subviews should be created for the odd and even
+  // indexed BF16 element because x86vector.avx.bcst_to_f32.packed op
+  // loads and broadcast the first BF16 element into packed F32. It
+  // cannot distinguish between even and odd BF16 elements within a
+  // packed pair.
   if (mnDim == 1) {
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
     sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
@@ -193,8 +201,6 @@ struct VectorContractBF16ToFMA
 
     // Lower vector.contract to FMAs with help of BF16 packed ops.
     auto loc = contractOp.getLoc();
-    llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview;
-    llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview;
 
     // create the unit-dimension LHS or RHS subview and the
     // corresponding non-unit dimension LHS or RHS subview on the other-side.
@@ -202,30 +208,40 @@ struct VectorContractBF16ToFMA
     // vector<1x8x2xbf16>, we create two subview for the LHS and one subview
     // for the RHS. In the opposite case (non-unit dimension on the LHS), we
     // do vice-versa.
-    if ((nonUnitDimRhs.size() - 1) > 0) {
-      auto unitSubview = getSubviewFromVectorInput(
-          loc, rewriter, contractOp.getLhs(), 1, 1, 2);
-      auto nonUnitSubview = getSubviewFromVectorInput(
-          loc, rewriter, contractOp.getRhs(), nonUnitDimRhs.front(), 2, 2);
-      if (failed(unitSubview) || failed(nonUnitSubview))
-        return rewriter.notifyMatchFailure(
-            contractOp, " The input source is not MemRef Type.");
-
-      unitDimSubview = *unitSubview;
-      nonUnitDimSubview = *nonUnitSubview;
-
-    } else {
-      auto nonUnitSubview = getSubviewFromVectorInput(
-          loc, rewriter, contractOp.getLhs(), nonUnitDimRhs.front(), 2, 3);
-      auto unitSubview = getSubviewFromVectorInput(
-          loc, rewriter, contractOp.getRhs(), 1, 1, 2);
-      if (failed(unitSubview) || failed(nonUnitSubview))
-        return rewriter.notifyMatchFailure(
-            contractOp, " The input source is not MemRef Type.");
-
-      unitDimSubview = *unitSubview;
-      nonUnitDimSubview = *nonUnitSubview;
-    }
+    bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+    // Select which operand is "unit" and which is "non-unit".
+    Value unitSrc =
+        rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+    Value nonUnitSrc =
+        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+
+    // mnDim index differs depending on the orientation.
+    int unitMnDim = rhsHasMultipleNonUnitDims ? 2 : 2;    // same for both
+    int nonUnitMnDim = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+
+    // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
+    int unitVnni = 1;
+    int nonUnitVnni = 2;
+
+    // Non-unit dim size.
+    int nonUnitSize = nonUnitDimRhs.front();
+
+    // Build subviews.
+    auto unitSubview = getSubviewFromVectorInput(
+        loc, rewriter, unitSrc, /*size=*/1, unitVnni, unitMnDim);
+
+    auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+                                                    /*size=*/nonUnitSize,
+                                                    nonUnitVnni, nonUnitMnDim);
+
+    // Check failures once.
+    if (failed(unitSubview) || failed(nonUnitSubview))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The input source is not MemRef Type.");
+
+    llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview = *unitSubview;
+    llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview =
+        *nonUnitSubview;
 
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc,
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index a934188bb058d..f7e71a32766ab 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -66,7 +66,7 @@ bool mlir::x86vector::isInVnniLayout(Operation *op,
   // The input matrix dimensions layout must match the following:
   //   - matrix A - [...][K/vnniFactor][vnniFactor]
   //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
-  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
   if (failed(maybeIters))
     return false;
   SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;

>From e05648e0a4ccb0a3291312b58eb4a0f33dc27a1e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 10 Dec 2025 06:24:38 -0800
Subject: [PATCH 5/8] code refactor: new comments, code simplification etc.

---
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  |  11 --
 .../Transforms/VectorContractBF16ToFMA.cpp    | 139 +++++++++---------
 .../Dialect/X86Vector/Utils/CMakeLists.txt    |   1 +
 .../X86Vector/Utils/X86VectorUtils.cpp        |   7 +-
 4 files changed, 73 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 8a76009ddb907..2de9a3122cbd9 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -16,21 +16,10 @@
 #include <string>
 
 namespace mlir {
-class Type;
-class ShapedType;
-class OpOperand;
-class AffineDimExpr;
 class AffineMap;
-class VectorType;
 class Operation;
 
 namespace x86vector {
-enum class VnniOperandRank {
-  TRANSPOSE = 3,
-  GEMM = 3,
-  BRGEMM_INS = 4,
-  BRGEMM_OUTS = 3
-};
 
 // Return true if the operation is in VNNI layout.
 // Optionally, the check can be constrained to a specific VNNI blocking factor.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 375364a4f1c7b..0d3576c4d5e57 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -1,5 +1,4 @@
-//===- VectorContractBF16ToFMA.cpp
-//--------------------------------------------===//
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -29,83 +28,82 @@ using namespace mlir::x86vector;
 // reads and creates subviews for the BF16 packed-operations to
 // broadcast or load BF16 elements as F32 packed elements.
 //
-// For example:
+// Example(1) Unit Dim:
 // ```
 //   vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
-//   vector.load %arg0[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
 // ```
 // to
 // ```
 //   memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
-//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
 //   memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
 // ```
-static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
-getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
-                          mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
-                          int64_t mnDimIndx) {
-
-  llvm::SmallVector<mlir::memref::SubViewOp> subviews;
-
-  Value srcOperation;
-  SmallVector<OpFoldResult> indexVals;
+//
+// Example(2) Non-unit Dim:
+// ```
+//   vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+                          int64_t mnDimSize, int64_t vnniDimSize,
+                          int64_t mnDimIdx) {
 
   Operation *defOp = prodOp.getDefiningOp();
   if (!defOp)
     return failure();
 
-  llvm::TypeSwitch<Operation *>(defOp)
-      .Case<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(
-          [&](auto readOp) {
-            srcOperation = readOp.getOperand(0);
-            indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
-                                                  readOp.getIndices().end());
-          });
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        srcOperation = readOp.getOperand(0);
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+      });
 
   if (!srcOperation)
     return failure();
 
   Type srcType = srcOperation.getType();
-  if (!llvm::isa<mlir::MemRefType>(srcType))
+  if (!llvm::isa<MemRefType>(srcType))
     return failure();
 
-  llvm::SmallVector<OpFoldResult> strides;
-  llvm::SmallVector<OpFoldResult> sizes;
-
   auto nonVNNIDimSize = indexVals.size() - 1;
   // Create the size and stride offsets.
-  for (unsigned int i = 0; i < nonVNNIDimSize; i++) {
-    strides.push_back(rewriter.getIndexAttr(1));
-    sizes.push_back(rewriter.getIndexAttr(1));
-  }
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+  SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
 
   strides.push_back(rewriter.getIndexAttr(1));
-  sizes.push_back(rewriter.getIndexAttr(vnniDim));
+  sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
 
-  // update the unit/nonUnit Dim size eiither it is A(LHS) or B(RHS).
-  sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+  // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
+  sizes[indexVals.size() - mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
 
-  // for unitDim, first broadcast odd element, so index is set to C1.
-  if (mnDim == 1) {
+  // for unitDim, first broadcast odd element, so index is set to 1.
+  if (mnDimSize == 1)
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
-  }
 
+  llvm::SmallVector<memref::SubViewOp> subviews;
   auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
                                            indexVals, sizes, strides);
   subviews.push_back(subview);
 
   // For unit-dims, two subviews should be created for the odd and even
-  // indexed BF16 element because x86vector.avx.bcst_to_f32.packed op
-  // loads and broadcast the first BF16 element into packed F32. It
+  // element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed
+  // op loads and broadcast the first BF16 element into packed F32. It
   // cannot distinguish between even and odd BF16 elements within a
   // packed pair.
-  if (mnDim == 1) {
+  if (mnDimSize == 1) {
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
     sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
 
-    auto unitDimEvenIndxSubview = memref::SubViewOp::create(
+    auto unitDimEvenIdxSubview = memref::SubViewOp::create(
         rewriter, loc, srcOperation, indexVals, sizes, strides);
-    subviews.push_back(unitDimEvenIndxSubview);
+    subviews.push_back(unitDimEvenIdxSubview);
   }
 
   return subviews;
@@ -140,7 +138,7 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(contractOp,
                                          "Expects add combining kind.");
 
-    // TODO: Move this validation to a comon utility folder. Planned to
+    // TODO: Move this validation to a common utility folder. Planned to
     // do once (code refactoring), all architecture specific nanokernel
     // passes are merged into the repo.
     VectorType lhsTy = contractOp.getLhsType();
@@ -149,10 +147,19 @@ struct VectorContractBF16ToFMA
                                          "Only BF16 lowering is supported.");
 
     if (!isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(), 2))
+                        contractOp.getIndexingMapsArray(),
+                        /*blockingFactor=*/2))
       return rewriter.notifyMatchFailure(contractOp,
                                          "Input matrices not in VNNI format.");
 
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 acumulation supported for BF16 type.");
+
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimLhs;
     llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
@@ -174,13 +181,13 @@ struct VectorContractBF16ToFMA
           contractOp,
           "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
-    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    /* VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
       return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
 
     if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
       return rewriter.notifyMatchFailure(
-          contractOp, "Only F32 acumulation supported for BF16 type.");
+          contractOp, "Only F32 acumulation supported for BF16 type."); */
 
     ArrayRef<int64_t> accShape = accTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimAcc;
@@ -216,8 +223,7 @@ struct VectorContractBF16ToFMA
         rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
 
     // mnDim index differs depending on the orientation.
-    int unitMnDim = rhsHasMultipleNonUnitDims ? 2 : 2;    // same for both
-    int nonUnitMnDim = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+    int mnDimIdx = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
 
     // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
     int unitVnni = 1;
@@ -228,55 +234,52 @@ struct VectorContractBF16ToFMA
 
     // Build subviews.
     auto unitSubview = getSubviewFromVectorInput(
-        loc, rewriter, unitSrc, /*size=*/1, unitVnni, unitMnDim);
+        loc, rewriter, unitSrc, /*size=*/1, unitVnni, mnDimIdx);
 
-    auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
-                                                    /*size=*/nonUnitSize,
-                                                    nonUnitVnni, nonUnitMnDim);
+    auto nonUnitSubview =
+        getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+                                  /*size=*/nonUnitSize, nonUnitVnni, mnDimIdx);
 
     // Check failures once.
     if (failed(unitSubview) || failed(nonUnitSubview))
       return rewriter.notifyMatchFailure(
           contractOp, "The input source is not MemRef Type.");
 
-    llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview = *unitSubview;
-    llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview =
-        *nonUnitSubview;
+    SmallVector<memref::SubViewOp> unitDimSubview = *unitSubview;
+    SmallVector<memref::SubViewOp> nonUnitDimSubview = *nonUnitSubview;
 
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
-    mlir::VectorType dstType =
-        mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+    VectorType dstType =
+        VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
 
     // Load, broadcast, and do FMA for odd indexed BF16 elements.
-    auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+    auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[0]);
-    auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+    auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
         rewriter, loc, dstType, nonUnitDimSubview[0]);
-    auto oddIndxFMA =
-        vector::FMAOp::create(rewriter, loc, loadBcstOddIndxElementToF32,
-                              loadOddIndxElementF32, castAcc);
+    auto oddIdxFMA =
+        vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
+                              loadOddIdxElementF32, castAcc);
 
     llvm::SmallVector<Operation *> users;
     for (OpResult result : contractOp->getResults())
       for (Operation *user : result.getUsers())
         users.push_back(user);
 
-    if (users.size() == 1) {
+    if (users.size() == 1)
       rewriter.setInsertionPoint(users[0]);
-    }
 
     // Load, broadcast, and do FMA for even indexed BF16 elements.
-    auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+    auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[1]);
-    auto loadEvenIndxElementF32 =
-        x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
-                                                       nonUnitDimSubview[0]);
+    auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
+        rewriter, loc, dstType, nonUnitDimSubview[0]);
     vector::FMAOp fma =
-        vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
-                              loadEvenIndxElementF32, oddIndxFMA);
+        vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
+                              loadEvenIdxElementF32, oddIdxFMA);
 
     auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
     rewriter.replaceOp(contractOp, castFma);
diff --git a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
index 6a2da861737ed..595846489f6c9 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
@@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRX86VectorUtils
   MLIRDialectUtils
   MLIRFuncDialect
   MLIRIR
+  MLIRLinalgDialect
   MLIRVectorDialect
   )
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index f7e71a32766ab..bb31686a14616 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -5,10 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file implements utility methods for working with the Vector dialect.
-//
-//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
 
@@ -19,9 +15,8 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
-#define DEBUG_TYPE "x86vector-utils"
-
 using namespace mlir;
+using namespace mlir::x86vector;
 
 static FailureOr<SmallVector<mlir::utils::IteratorType>>
 inferIteratorsFromOutMap(AffineMap map) {

>From 18e77daf35703c0ca9606186bdab714d4571eeb0 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 11 Dec 2025 07:53:07 -0800
Subject: [PATCH 6/8] creating new namespace and removing hard coded values +
 others

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 51 ++++++++++---------
 .../X86Vector/Utils/X86VectorUtils.cpp        | 12 +++--
 .../vector-contract-bf16-to-fma.mlir          | 46 +++++++++++++++++
 3 files changed, 81 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 0d3576c4d5e57..b9ccee9bd7186 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -24,7 +24,7 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
-// This function retrives the source operation of the load or transfer
+// This function retrieves the source operation of the load or transfer
 // reads and creates subviews for the BF16 packed-operations to
 // broadcast or load BF16 elements as F32 packed elements.
 //
@@ -81,7 +81,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
 
   // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
-  sizes[indexVals.size() - mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+  sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
 
   // for unitDim, first broadcast odd element, so index is set to 1.
   if (mnDimSize == 1)
@@ -97,6 +97,9 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   // op loads and broadcast the first BF16 element into packed F32. It
   // cannot distinguish between even and odd BF16 elements within a
   // packed pair.
+  // Example:
+  // memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
+  // memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
   if (mnDimSize == 1) {
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
     sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
@@ -156,20 +159,32 @@ struct VectorContractBF16ToFMA
     if (!accTy)
       return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
 
-    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
+    if (!accTy.getElementType().isF32())
       return rewriter.notifyMatchFailure(
           contractOp, "Only F32 acumulation supported for BF16 type.");
 
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimLhs;
-    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
-                  [](int64_t dim) { return dim != 1; });
+    llvm::SmallVector<unsigned> nonUnitIndicesLhs;
+
+    for (auto it : llvm::enumerate(lhsShape)) {
+      if (it.value() != 1) {
+        nonUnitDimLhs.push_back(it.value());
+        nonUnitIndicesLhs.push_back(it.index());
+      }
+    }
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimRhs;
-    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
-                  [](int64_t dim) { return dim != 1; });
+    llvm::SmallVector<unsigned> nonUnitIndicesRhs;
+
+    for (auto it : llvm::enumerate(rhsShape)) {
+      if (it.value() != 1) {
+        nonUnitDimRhs.push_back(it.value());
+        nonUnitIndicesRhs.push_back(it.index());
+      }
+    }
 
     if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(contractOp,
@@ -181,14 +196,6 @@ struct VectorContractBF16ToFMA
           contractOp,
           "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
-    /* VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    if (!accTy)
-      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
-
-    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
-      return rewriter.notifyMatchFailure(
-          contractOp, "Only F32 acumulation supported for BF16 type."); */
-
     ArrayRef<int64_t> accShape = accTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimAcc;
     llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
@@ -223,7 +230,9 @@ struct VectorContractBF16ToFMA
         rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
 
     // mnDim index differs depending on the orientation.
-    int mnDimIdx = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+    int mnDimIdx = rhsHasMultipleNonUnitDims
+                       ? nonUnitIndicesRhs.front()
+                       : nonUnitIndicesLhs.front(); // A or B
 
     // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
     int unitVnni = 1;
@@ -264,13 +273,9 @@ struct VectorContractBF16ToFMA
         vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
                               loadOddIdxElementF32, castAcc);
 
-    llvm::SmallVector<Operation *> users;
-    for (OpResult result : contractOp->getResults())
-      for (Operation *user : result.getUsers())
-        users.push_back(user);
-
-    if (users.size() == 1)
-      rewriter.setInsertionPoint(users[0]);
+    OpResult vcResult = contractOp->getResult(0);
+    if (vcResult.hasOneUse())
+      rewriter.setInsertionPoint(*vcResult.getUsers().begin());
 
     // Load, broadcast, and do FMA for even indexed BF16 elements.
     auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index bb31686a14616..ccb2e92fdd9e2 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -15,8 +15,8 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
-using namespace mlir;
-using namespace mlir::x86vector;
+namespace mlir {
+namespace x86vector {
 
 static FailureOr<SmallVector<mlir::utils::IteratorType>>
 inferIteratorsFromOutMap(AffineMap map) {
@@ -32,9 +32,8 @@ inferIteratorsFromOutMap(AffineMap map) {
 
 // Returns true if the operation is in VNNI layout.
 // Optionally, the check can be constrained to a specific VNNI blocking factor.
-bool mlir::x86vector::isInVnniLayout(Operation *op,
-                                     ArrayRef<AffineMap> indexingMaps,
-                                     std::optional<unsigned> blockingFactor) {
+bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+                    std::optional<unsigned> blockingFactor) {
   // Narrow down type operations - VNNI only applies to contractions.
   FailureOr<linalg::ContractionDimensions> dims =
       linalg::inferContractionDims(indexingMaps);
@@ -104,3 +103,6 @@ bool mlir::x86vector::isInVnniLayout(Operation *op,
 
   return true;
 }
+
+} // namespace x86vector
+} // namespace mlir
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index 0b92e9367365f..aa7cfc84b0f7b 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -230,6 +230,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1x1x1x2xbf16>
+!vecB = vector<1x1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x1x4x1x2xbf16>
+!memrefB = memref<1x1x1x32x2xbf16>
+#map = affine_map<(d5, d0, d4, d1, d2, d3) -> (d5, d0, d1, d3, d4)>
+#map1 = affine_map<(d5, d0, d4, d1, d2, d3) -> (d5, d0, d3, d2, d4)>
+#map2 = affine_map<(d5, d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @many_dimensions(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @many_dimensions
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>

>From 406b66c077ae724483ef81f68cb13b3da692c981 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 12 Dec 2025 01:19:09 -0800
Subject: [PATCH 7/8] added new test-cases

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 80 ++++++++---------
 .../vector-contract-bf16-to-fma.mlir          | 85 +++++++++++++++----
 2 files changed, 103 insertions(+), 62 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index b9ccee9bd7186..cae8cfbf7c00f 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -48,48 +48,61 @@ using namespace mlir::x86vector;
 // ```
 static FailureOr<SmallVector<memref::SubViewOp>>
 getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
-                          int64_t mnDimSize, int64_t vnniDimSize,
-                          int64_t mnDimIdx) {
+                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
 
   Operation *defOp = prodOp.getDefiningOp();
   if (!defOp)
     return failure();
 
-  Value srcOperation;
+  Value srcBuff;
   SmallVector<OpFoldResult> indexVals;
   llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
       [&](auto readOp) {
-        srcOperation = readOp.getOperand(0);
+        srcBuff = readOp.getOperand(0);
         indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                               readOp.getIndices().end());
       });
 
-  if (!srcOperation)
+  if (!srcBuff)
     return failure();
 
-  Type srcType = srcOperation.getType();
+  Type srcType = srcBuff.getType();
   if (!llvm::isa<MemRefType>(srcType))
     return failure();
 
+  int64_t mnDimSize = 1;
+  unsigned mnDimIdx = 0;
+
+  if (!isUnitDim) {
+    for (auto it : llvm::enumerate(nonUnitDimShape)) {
+      if (it.value() != 1) {
+        mnDimSize = it.value();
+        mnDimIdx = it.index();
+        break;
+      }
+    }
+  }
+
+  int vnniDimSize = isUnitDim ? 1 : 2;
+
   auto nonVNNIDimSize = indexVals.size() - 1;
   // Create the size and stride offsets.
   auto one = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+  SmallVector<OpFoldResult> strides(indexVals.size(), one);
   SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
 
-  strides.push_back(rewriter.getIndexAttr(1));
   sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
 
   // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
   sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
 
   // for unitDim, first broadcast odd element, so index is set to 1.
-  if (mnDimSize == 1)
+  if (isUnitDim)
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
 
   llvm::SmallVector<memref::SubViewOp> subviews;
-  auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
-                                           indexVals, sizes, strides);
+  auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+                                           sizes, strides);
   subviews.push_back(subview);
 
   // For unit-dims, two subviews should be created for the odd and even
@@ -97,6 +110,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   // op loads and broadcast the first BF16 element into packed F32. It
   // cannot distinguish between even and odd BF16 elements within a
   // packed pair.
+  //
   // Example:
   // memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
   // memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
@@ -105,7 +119,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
     sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
 
     auto unitDimEvenIdxSubview = memref::SubViewOp::create(
-        rewriter, loc, srcOperation, indexVals, sizes, strides);
+        rewriter, loc, srcBuff, indexVals, sizes, strides);
     subviews.push_back(unitDimEvenIdxSubview);
   }
 
@@ -165,26 +179,14 @@ struct VectorContractBF16ToFMA
 
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimLhs;
-    llvm::SmallVector<unsigned> nonUnitIndicesLhs;
-
-    for (auto it : llvm::enumerate(lhsShape)) {
-      if (it.value() != 1) {
-        nonUnitDimLhs.push_back(it.value());
-        nonUnitIndicesLhs.push_back(it.index());
-      }
-    }
+    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+                  [](int64_t dim) { return dim != 1; });
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimRhs;
-    llvm::SmallVector<unsigned> nonUnitIndicesRhs;
-
-    for (auto it : llvm::enumerate(rhsShape)) {
-      if (it.value() != 1) {
-        nonUnitDimRhs.push_back(it.value());
-        nonUnitIndicesRhs.push_back(it.index());
-      }
-    }
+    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+                  [](int64_t dim) { return dim != 1; });
 
     if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(contractOp,
@@ -229,25 +231,15 @@ struct VectorContractBF16ToFMA
     Value nonUnitSrc =
         rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
 
-    // mnDim index differs depending on the orientation.
-    int mnDimIdx = rhsHasMultipleNonUnitDims
-                       ? nonUnitIndicesRhs.front()
-                       : nonUnitIndicesLhs.front(); // A or B
-
-    // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
-    int unitVnni = 1;
-    int nonUnitVnni = 2;
-
-    // Non-unit dim size.
-    int nonUnitSize = nonUnitDimRhs.front();
+    ArrayRef<int64_t> nonUnitDimShape =
+        rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
 
     // Build subviews.
-    auto unitSubview = getSubviewFromVectorInput(
-        loc, rewriter, unitSrc, /*size=*/1, unitVnni, mnDimIdx);
+    auto unitSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
+                                                 nonUnitDimShape, true);
 
-    auto nonUnitSubview =
-        getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
-                                  /*size=*/nonUnitSize, nonUnitVnni, mnDimIdx);
+    auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+                                                    nonUnitDimShape, false);
 
     // Check failures once.
     if (failed(unitSubview) || failed(nonUnitSubview))
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index aa7cfc84b0f7b..c55a859340600 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -230,6 +230,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<8x1x2xbf16>
+!vecB = vector<1x1x2xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1x2xbf16>
+!memrefB = memref<1x4x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_to_fma_load_bcst_B(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @matmul_to_fma_load_bcst_B
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x1x2xbf16>
 !vecB = vector<1x1x1x8x2xbf16>
 !vecC = vector<1x8xf32>
@@ -276,26 +322,29 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
-
-func.func @negative_tensor_type(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!tensorA = tensor<4x1x2xbf16>
+!tensorB = tensor<1x32x2xbf16>
+#map = affine_map<(d1, d2, d3, d4) -> (d2, d4, d1)>
+#map1 = affine_map<(d1, d2, d3, d4) -> (d4, d3, d1)>
+#map2 = affine_map<(d1, d2, d3, d4) -> (d2, d3)>
+func.func @negative_tensor_type(%arg0: !tensorA, %arg1: !tensorB, %arg2: !vecC) -> !vecC {
   %0 = ub.poison : bf16
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
-  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
-  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
-  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
-  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
-  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
-  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
-  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
-  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
-  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
-  return %8 : vector<1x16xf32>
+  %c8 = arith.constant 8 : index
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !tensorA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c8, %c0], %0 {in_bounds = [true, true, true]} :
+        !tensorB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
 }
 
 // CHECK-LABEL: @negative_tensor_type

>From 350f93566b33ff80eeb04874df247d7552838f61 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 16 Dec 2025 03:24:37 -0800
Subject: [PATCH 8/8] added new validation around vector.transfer_read op + new
 test-cases around it

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 104 +++++++---
 .../vector-contract-bf16-to-fma.mlir          | 177 ++++++++++++++++++
 2 files changed, 255 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index cae8cfbf7c00f..7e13d84f74eef 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -24,6 +26,63 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+// Returns true if `prodOp`, an LHS/RHS operand of the contraction, is
+// produced by a vector.load or vector.transfer_read whose source memref can
+// be read directly by the BF16 packed broadcast/convert ops. Returns false if:
+//   - `prodOp` has no defining op, or it is not a load/transfer_read,
+//   - a transfer_read may read out-of-bounds, is masked, or permutes dims,
+//   - the source is not a memref or its innermost stride is not static 1,
+//   - the innermost (VNNI) index is not the constant 0.
+static bool validateVectorProdOp(Value prodOp) {
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // The rewrite reads memory directly, so a transfer_read must read exactly
+  // that memory: masked-off or out-of-bounds lanes would otherwise be loaded
+  // from memory instead of taking the padding value.
+  if (auto readOp = llvm::dyn_cast<TransferReadOp>(defOp)) {
+    if (readOp.hasOutOfBoundsDim() || readOp.getMask() ||
+        !readOp.getPermutationMap().isIdentity())
+      return false;
+  }
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        srcBuff = readOp.getOperand(0);
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+      });
+
+  // Not produced by a vector.load / vector.transfer_read.
+  if (!srcBuff)
+    return false;
+
+  // Return false, if the source is not a memref type.
+  auto memrefTy = llvm::dyn_cast<MemRefType>(srcBuff.getType());
+  if (!memrefTy)
+    return false;
+
+  // The packed ops read each VNNI pair as contiguous memory, so the innermost
+  // stride must be statically 1 (a dynamic stride may be anything). Use the
+  // fallible overload so non-strided layouts are rejected, not asserted on.
+  SmallVector<int64_t> strides;
+  int64_t offset = 0;
+  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
+    return false;
+  if (!strides.empty() && strides.back() != 1)
+    return false;
+
+  // The innermost (VNNI) index must be the constant 0. A 0-d read has no
+  // indices at all, so check for emptiness before calling `back()`.
+  if (indexVals.empty() || getConstantIntValue(indexVals.back()) != 0)
+    return false;
+
+  return true;
+}
+
 // This function retrieves the source operation of the load or transfer
 // reads and creates subviews for the BF16 packed-operations to
 // broadcast or load BF16 elements as F32 packed elements.
@@ -46,13 +105,11 @@ using namespace mlir::x86vector;
 // ```
 //   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
 // ```
-static FailureOr<SmallVector<memref::SubViewOp>>
+static SmallVector<memref::SubViewOp>
 getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
                           ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
 
   Operation *defOp = prodOp.getDefiningOp();
-  if (!defOp)
-    return failure();
 
   Value srcBuff;
   SmallVector<OpFoldResult> indexVals;
@@ -63,13 +120,6 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
                                               readOp.getIndices().end());
       });
 
-  if (!srcBuff)
-    return failure();
-
-  Type srcType = srcBuff.getType();
-  if (!llvm::isa<MemRefType>(srcType))
-    return failure();
-
   int64_t mnDimSize = 1;
   unsigned mnDimIdx = 0;
 
@@ -215,6 +265,20 @@ struct VectorContractBF16ToFMA
           contractOp, "BF16 packed load operation expects non-unit (LHR or "
                       "RHS) dim and acc dim of size 4/8.");
 
+    if (!validateVectorProdOp(contractOp.getLhs()))
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "The LHS is in invalid format. Either it has false inbound or "
+          "non-identical permuation map or the vnni offset is not zero or src "
+          "is not MemRef type or has non-unit vnni stride");
+
+    if (!validateVectorProdOp(contractOp.getRhs()))
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "The LHS is in invalid format. Either it has false inbound or "
+          "non-identical permuation map or the vnni offset is not zero or src "
+          "is not MemRef type or has non-unit vnni stride");
+
     // Lower vector.contract to FMAs with help of BF16 packed ops.
     auto loc = contractOp.getLoc();
 
@@ -235,19 +299,11 @@ struct VectorContractBF16ToFMA
         rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
 
     // Build subviews.
-    auto unitSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
-                                                 nonUnitDimShape, true);
-
-    auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
-                                                    nonUnitDimShape, false);
+    auto unitDimSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
+                                                    nonUnitDimShape, true);
 
-    // Check failures once.
-    if (failed(unitSubview) || failed(nonUnitSubview))
-      return rewriter.notifyMatchFailure(
-          contractOp, "The input source is not MemRef Type.");
-
-    SmallVector<memref::SubViewOp> unitDimSubview = *unitSubview;
-    SmallVector<memref::SubViewOp> nonUnitDimSubview = *nonUnitSubview;
+    auto nonUnitDimSubview = getSubviewFromVectorInput(
+        loc, rewriter, nonUnitSrc, nonUnitDimShape, false);
 
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc,
@@ -265,10 +321,6 @@ struct VectorContractBF16ToFMA
         vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
                               loadOddIdxElementF32, castAcc);
 
-    OpResult vcResult = contractOp->getResult(0);
-    if (vcResult.hasOneUse())
-      rewriter.setInsertionPoint(*vcResult.getUsers().begin());
-
     // Load, broadcast, and do FMA for even indexed BF16 elements.
     auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[1]);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index c55a859340600..b8797f5a7d2d9 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -385,6 +385,183 @@ func.func @negative_no_memref_src(
 // CHECK: vector.contract
 // CHECK-NOT: vector.fma
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+// Negative case: the RHS vector.load starts at index %c1 (not 0) of the
+// innermost (VNNI) dimension, so the BF16-to-FMA pattern must leave the
+// vector.contract untouched.
+func.func @negative_vnni_offset_1(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c1] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_vnni_offset_1
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+#perm0 = affine_map<(d1, d2, d3) -> (d2, d1, d3)>
+// Negative case: the LHS vector.transfer_read carries a non-identity
+// permutation_map (#perm0 swaps the two outer dims), so the pattern must not
+// rewrite the contraction.
+func.func @negative_perm_map_not_identical(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {permutation_map = #perm0,
+        in_bounds = [true, true, true]} : !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_perm_map_not_identical
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+// Negative case: the RHS is read through a subview whose innermost (VNNI)
+// dimension has stride 2, so the pattern must not rewrite the contraction.
+// NOTE(review): sizes [.., 2] with strides [.., 2] on an innermost dim of
+// extent 2 would address index 2 of that dim; this appears to be accepted only
+// because the subview offsets are dynamic SSA values — confirm this is intended.
+func.func @negative_non_unit_stride(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %subview_1 = memref.subview %arg1[%c0, %c0, %c0] [1, 16, 2] [1, 1, 2] :
+               !memrefB to memref<1x16x2xbf16, strided<[64, 2, 2], offset: ?>>
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        memref<1x16x2xbf16, strided<[64, 2, 2], offset: ?>>, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_non_unit_stride
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+// Negative case: the LHS vector.transfer_read is not fully in-bounds
+// (in_bounds = [true, false, true] with a dynamic index %arg3), so the pattern
+// must not rewrite the contraction.
+func.func @negative_false_bound(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC, %arg3: index) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+
+  %1 = vector.transfer_read %arg0[%c0, %arg3, %c0], %0 {in_bounds = [true, false, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_false_bound
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op



More information about the Mlir-commits mailing list