[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

Arun Thangamani llvmlistbot at llvm.org
Thu Feb 12 01:34:55 PST 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/174590

>From 5ddf7e96ddeb975534668b0bd0c991a4c483a171 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 6 Jan 2026 05:44:34 -0800
Subject: [PATCH 01/17] initial commit for shuffle VC output

---
 .../TransformOps/X86VectorTransformOps.td     |  11 +
 .../mlir/Dialect/X86Vector/Transforms.h       |   3 +
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  |  26 ++
 .../TransformOps/X86VectorTransformOps.cpp    |   5 +
 .../X86Vector/Transforms/CMakeLists.txt       |   1 +
 .../ShuffleBF16VectorContractResult.cpp       | 150 ++++++++
 .../X86Vector/Utils/X86VectorUtils.cpp        | 325 ++++++++++++++++++
 .../shuffle-bf16-vector-contract-result.mlir  |  66 ++++
 8 files changed, 587 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
 create mode 100644 mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c73eadf82167..00c611a9f3a7a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,6 +60,19 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }

+def ApplyShuffleBF16VectorContractResultPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_bf16_vector_contract_result",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns that pair flat-layout (non-VNNI) BF16 `vector.contract`
+    operations with F32 accumulators and insert `vector.shuffle` operations
+    on their accumulators right after they are read and on their results
+    right before they are written.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index c25cdaf2d9428..538ffaac79998 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,13 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);

+// Collect patterns that pair flat-layout (non-VNNI) BF16 vector.contract ops
+// with F32 accumulators and insert vector.shuffle ops on their accumulators
+// (right after they are read) and on their results (right before they are
+// written).
+void populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 2de9a3122cbd9..7267c5c2032d9 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 #define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
@@ -26,6 +29,53 @@ namespace x86vector {
 bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
                     std::optional<unsigned> blockingFactor = std::nullopt);

+/// Returns true if `pairContOp` can be paired with `contractOp`: they must
+/// share their LHS when `rhsHasMultipleNonUnitDims` is set and their RHS
+/// otherwise, and their other operands must be read (vector.transfer_read /
+/// vector.load) from the same buffer at indices that are expected to be,
+/// position by position, identical or larger by `nonUnitDimValue` for
+/// `pairContOp`.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+                                vector::ContractionOp pairContOp,
+                                bool rhsHasMultipleNonUnitDims,
+                                int64_t nonUnitDimValue);
+
+/// Follows `v` up through single-operand producers and scf.for iter_args and
+/// returns the vector.transfer_read / vector.load it originates from, or
+/// nullptr if the chain ends elsewhere (e.g. at a vector.shape_cast or
+/// vector.shuffle).
+Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
+
+/// Follows the users of `v` (through scf.for and scf.yield to the results of
+/// the enclosing op) and returns the first vector.transfer_write /
+/// vector.store found, or nullptr if none is found or a vector.shape_cast /
+/// vector.shuffle user is encountered.
+Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
+
+/// Shuffles the vectors stored by the write-like ops `op` and `op1` with the
+/// interleaving masks for `nonUnitDimAcc` (8 or 16) and makes each op store
+/// its shuffled vector instead. Both ops must be in the same block.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, int64_t nonUnitDimAcc,
+                              VectorType accTy);
+
+/// Shuffles the vectors produced by the read-like ops `op` and `op1` with the
+/// interleaving masks for `nonUnitDimAcc` (8 or 16) and redirects their
+/// vector.contract / scf.for users to the shuffled vectors. Both ops must be
+/// in the same block.
+void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
+                            Operation *op1, vector::ContractionOp contractOp,
+                            vector::ContractionOp pairContractOp,
+                            int64_t nonUnitDimAcc, VectorType accTy);
+
+/// Like shuffleAfterReadLikeOp, but with fixed 32-element interleaving masks:
+/// `nonUnitDimAcc` must be 32 and the results of `op` / `op1` must hold 32
+/// elements each.
+void shuffleNonUnitDimOperand(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, vector::ContractionOp contractOp,
+                              vector::ContractionOp pairContractOp,
+                              int64_t nonUnitDimAcc, VectorType Ty);
+
 } // namespace x86vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e77d30c9c5ffb..e40ddd3a4b1c0 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,6 +42,12 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }

+void mlir::transform::ApplyShuffleBF16VectorContractResultPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  // The patterns are implemented in the MLIRX86VectorTransforms library.
+  x86vector::populateShuffleBF16VectorContractResultPatterns(patterns);
+}
+
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index bbd9be880eb0a..acbc7fcfb635e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToPackedTypeDotProduct.cpp
   VectorContractBF16ToFMA.cpp
   SinkVectorProducerOps.cpp
+  ShuffleBF16VectorContractResult.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
new file mode 100644
index 0000000000000..2081de9f2be91
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -0,0 +1,158 @@
+//===-
+//ShuffleBF16VectorContractResult.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Returns how many dimensions of `shape` have an extent other than 1.
+static int64_t countNonUnitDims(ArrayRef<int64_t> shape) {
+  return llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+}
+
+namespace {
+
+/// Pairs two flat-layout (non-VNNI) BF16 vector.contract ops with F32
+/// accumulators whose only non-unit dimension is 8 or 16 and that share one
+/// input operand (see validatePairVectorContract), then shuffles both
+/// accumulators right after they are read and both results right before they
+/// are written.
+struct ShuffleBF16VectorContractResult
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // TODO: Move this validation to a common utility folder. Planned to
+    // do once (code refactoring), all architecture specific nanokernel
+    // passes are merged into the repo.
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    if (isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(),
+                       /*blockingFactor=*/2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices in VNNI format.");
+
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Wrong accumulator type.");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    // The accumulator must have exactly one non-unit dimension, of size 8 or
+    // 16: the shuffle masks are only defined for those two sizes.
+    SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accTy.getShape(), std::back_inserter(nonUnitDimAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    int64_t nonUnitDimValue = nonUnitDimAcc.front();
+    if (nonUnitDimValue != 8 && nonUnitDimValue != 16)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator dimension should be 8 or 16");
+
+    // An operand may have at most one non-unit dimension with a 16-wide
+    // accumulator and none with an 8-wide one, and at most one of LHS / RHS
+    // may exceed that limit. Compare the counts directly so that an operand
+    // without any non-unit dimension never counts as exceeding the limit.
+    int64_t maxNonUnitDims = nonUnitDimValue == 16 ? 1 : 0;
+    int64_t numNonUnitDimsLhs = countNonUnitDims(lhsTy.getShape());
+    int64_t numNonUnitDimsRhs =
+        countNonUnitDims(contractOp.getRhsType().getShape());
+    if (numNonUnitDimsLhs > maxNonUnitDims &&
+        numNonUnitDimsRhs > maxNonUnitDims)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape.");
+    bool rhsHasMultipleNonUnitDims = numNonUnitDimsRhs > maxNonUnitDims;
+
+    // Look forward in the block for a contraction that pairs with this one.
+    vector::ContractionOp pairContractOp;
+    for (Operation *nextOp = contractOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto candidate = dyn_cast<vector::ContractionOp>(nextOp);
+      if (candidate && validatePairVectorContract(contractOp, candidate,
+                                                  rhsHasMultipleNonUnitDims,
+                                                  nonUnitDimValue)) {
+        pairContractOp = candidate;
+        break;
+      }
+    }
+    if (!pairContractOp)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "No pairable vector.contract found.");
+
+    Operation *accReadOp0 =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+    Operation *accReadOp1 =
+        traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+    if (!accReadOp0 || !accReadOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Accumulators do not come from read-like ops.");
+
+    Operation *resultWriteOp0 =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+    Operation *resultWriteOp1 =
+        traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+    if (!resultWriteOp0 || !resultWriteOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Results do not reach write-like ops.");
+
+    // The shuffle helpers order each pair with Operation::isBeforeInBlock,
+    // which requires two ops of the same block; shuffling a value with itself
+    // would also be wrong.
+    if (accReadOp0 == accReadOp1 || resultWriteOp0 == resultWriteOp1 ||
+        accReadOp0->getBlock() != accReadOp1->getBlock() ||
+        resultWriteOp0->getBlock() != resultWriteOp1->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects distinct read/write ops within one block.");
+
+    shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                           pairContractOp, nonUnitDimValue, accTy);
+    shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+                             nonUnitDimValue, accTy);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShuffleBF16VectorContractResult>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index ccb2e92fdd9e2..65cd0773160ef 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -10,11 +10,18 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include <cassert>
+
 namespace mlir {
 namespace x86vector {
 
@@ -104,5 +111,305 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
   return true;
 }

+namespace {
+/// Two shuffle masks applied to the same pair of flattened input vectors:
+/// `maskLo` produces the first output vector and `maskHi` the second one.
+struct ShuffleMasks {
+  llvm::ArrayRef<int64_t> maskLo;
+  llvm::ArrayRef<int64_t> maskHi;
+};
+} // namespace
+
+/// Returns the interleaving masks for an accumulator whose single non-unit
+/// dimension is `nonUnitDimAcc` (8 or 16). For 8, the first halves (maskLo)
+/// and second halves (maskHi) of the two inputs are interleaved element by
+/// element; for 16, they are interleaved in groups of four elements.
+static ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+  // We only support these two layouts for now.
+  assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+         "Unsupported nonUnitDimAcc value");
+
+  static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
+  static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
+
+  static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
+                                         4, 5, 6, 7, 20, 21, 22, 23};
+  static constexpr int64_t maskHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
+                                         12, 13, 14, 15, 28, 29, 30, 31};
+
+  if (nonUnitDimAcc == 16)
+    return {maskLo16, maskHi16};
+  return {maskLo8, maskHi8};
+}
+
+/// Follows `v` up through single-operand producers and scf.for iter_args
+/// (to the matching init operand) and returns the vector.transfer_read /
+/// vector.load it originates from. Returns nullptr when the chain ends
+/// anywhere else: a vector.shape_cast / vector.shuffle (such as the ones the
+/// shuffle helpers below insert), a multi-operand producer, the scf.for
+/// induction variable, or a block argument not owned by an scf.for.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (true) {
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp))
+        return nullptr;
+      // Single-operand producers are assumed to forward the traced value.
+      if (defOp->getNumOperands() != 1)
+        return nullptr;
+      v = defOp->getOperand(0);
+      continue;
+    }
+
+    // `v` is a block argument; only scf.for iter_args are traced further.
+    auto barg = cast<BlockArgument>(v);
+    auto forOp = dyn_cast_or_null<scf::ForOp>(barg.getOwner()->getParentOp());
+    // Argument 0 of the loop body is the induction variable, not an iter_arg.
+    if (!forOp || barg.getArgNumber() == 0)
+      return nullptr;
+    v = forOp.getInitArgs()[barg.getArgNumber() - 1];
+  }
+}
+
+/// Follows the users of `v` and returns the first vector.transfer_write /
+/// vector.store reached. scf.for init operands and scf.yield operands are
+/// followed to the corresponding results of the loop / parent op; any other
+/// user is followed through all of its results. Returns nullptr if nothing
+/// is found, or as soon as a vector.shape_cast / vector.shuffle user is met.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    if (isa<scf::YieldOp>(user)) {
+      // A yielded value surfaces as the parent result with the same index,
+      // except for scf.while whose scf.yield feeds its "before" region.
+      Operation *parent = user->getParentOp();
+      unsigned idx = use.getOperandNumber();
+      if (!isa<scf::WhileOp>(parent) && idx < parent->getNumResults())
+        if (Operation *res =
+                traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+          return res;
+      continue;
+    }
+
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      // scf.for operand numbers also count the lower bound, upper bound and
+      // step, so map the init operand to its tied loop result explicitly.
+      if (OpResult tied = forOp.getTiedLoopResult(&use))
+        if (Operation *res = traceToVectorWriteLikeUserOperation(tied))
+          return res;
+      continue;
+    }
+
+    for (Value res : user->getResults())
+      if (Operation *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+  }
+
+  return nullptr;
+}
+
+/// Redirects the vector.contract and scf.for uses of `oldVal` to `newVal`.
+/// Uses are replaced through `rewriter` so that the pattern driver is
+/// notified of every modified op. `targetContract` is currently unused: all
+/// vector.contract users of `oldVal` are updated.
+static void rewriteUses(Value oldVal, Value newVal, Operation *targetContract,
+                        PatternRewriter &rewriter) {
+  (void)targetContract;
+  rewriter.replaceUsesWithIf(oldVal, newVal, [](OpOperand &use) {
+    return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
+  });
+}
+
+/// Shape-casts `a` and `b` to 1-D vectors of `numElements` elements (so each
+/// must hold exactly that many), shuffles them with `maskLo` and `maskHi`,
+/// and casts both shuffled vectors to `resultTy`. The ops are created at the
+/// current insertion point of `rewriter`. Returns {lo, hi}.
+static std::pair<Value, Value>
+shuffleFlattenedPair(PatternRewriter &rewriter, Location loc, Value a,
+                     Value b, int64_t numElements, ArrayRef<int64_t> maskLo,
+                     ArrayRef<int64_t> maskHi, VectorType resultTy) {
+  assert(static_cast<int64_t>(maskLo.size()) == numElements &&
+         static_cast<int64_t>(maskHi.size()) == numElements &&
+         "each shuffle mask must select `numElements` elements");
+  auto flatTy = VectorType::get(numElements, resultTy.getElementType());
+  Value flatA = vector::ShapeCastOp::create(rewriter, loc, flatTy, a);
+  Value flatB = vector::ShapeCastOp::create(rewriter, loc, flatTy, b);
+  Value lo =
+      vector::ShuffleOp::create(rewriter, loc, flatTy, flatA, flatB, maskLo);
+  Value hi =
+      vector::ShuffleOp::create(rewriter, loc, flatTy, flatA, flatB, maskHi);
+  Value newA = vector::ShapeCastOp::create(rewriter, loc, resultTy, lo);
+  Value newB = vector::ShapeCastOp::create(rewriter, loc, resultTy, hi);
+  return {newA, newB};
+}
+
+/// Shuffles the vectors produced by the read-like ops `opA` and `opB` with
+/// the interleaving masks for `nonUnitDimAcc`, right after the later of the
+/// two ops, and redirects their vector.contract / scf.for users to the
+/// shuffled vectors (`opA` users get the `maskLo` result, `opB` users the
+/// `maskHi` result). Both ops must be in the same block.
+void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
+                            Operation *opB, vector::ContractionOp contractA,
+                            vector::ContractionOp contractB,
+                            int64_t nonUnitDimAcc, VectorType accTy) {
+  assert(opA->getBlock() == opB->getBlock() &&
+         "expected read-like ops in the same block");
+  Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+  rewriter.setInsertionPointAfter(insertAfter);
+
+  ShuffleMasks masks = getShuffleMasks(nonUnitDimAcc);
+  std::pair<Value, Value> shuffled = shuffleFlattenedPair(
+      rewriter, insertAfter->getLoc(), opA->getResult(0), opB->getResult(0),
+      nonUnitDimAcc, masks.maskLo, masks.maskHi, accTy);
+
+  rewriteUses(opA->getResult(0), shuffled.first, contractA, rewriter);
+  rewriteUses(opB->getResult(0), shuffled.second, contractB, rewriter);
+}
+
+/// Shuffles the vectors stored by the write-like ops `opA` and `opB` with the
+/// interleaving masks for `nonUnitDimAcc`, right before the earlier of the
+/// two ops, and makes `opA` store the `maskLo` result and `opB` the `maskHi`
+/// result. Both ops must be in the same block and both stored vectors must
+/// be defined before the earlier op.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA,
+                              Operation *opB, int64_t nonUnitDimAcc,
+                              VectorType accTy) {
+  // Returns the vector stored by a vector.transfer_write / vector.store.
+  auto getWrittenVector = [](Operation *op) -> Value {
+    if (auto write = dyn_cast<vector::TransferWriteOp>(op))
+      return write.getVector();
+    if (auto store = dyn_cast<vector::StoreOp>(op))
+      return store.getValueToStore();
+    return nullptr;
+  };
+
+  Value vecA = getWrittenVector(opA);
+  Value vecB = getWrittenVector(opB);
+  assert(vecA && vecB && "expected vector write-like ops");
+  assert(opA->getBlock() == opB->getBlock() &&
+         "expected write-like ops in the same block");
+
+  Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
+  rewriter.setInsertionPoint(insertBefore);
+
+  // TODO: derive shuffle masks instead of hard-coding
+  ShuffleMasks masks = getShuffleMasks(nonUnitDimAcc);
+  std::pair<Value, Value> shuffled =
+      shuffleFlattenedPair(rewriter, insertBefore->getLoc(), vecA, vecB,
+                           nonUnitDimAcc, masks.maskLo, masks.maskHi, accTy);
+
+  // Operand 0 is the stored vector of both vector.transfer_write and
+  // vector.store. Update it through the rewriter so the driver is notified.
+  Value newVecA = shuffled.first;
+  Value newVecB = shuffled.second;
+  rewriter.modifyOpInPlace(opA, [&] { opA->setOperand(0, newVecA); });
+  rewriter.modifyOpInPlace(opB, [&] { opB->setOperand(0, newVecB); });
+}
+
+/// Same as shuffleAfterReadLikeOp, but with fixed 32-element masks: maskLo
+/// interleaves elements 0-3, 8-11, 16-19 and 24-27 of the two inputs element
+/// by element, and maskHi does the same for elements 4-7, 12-15, 20-23 and
+/// 28-31. `nonUnitDimAcc` must therefore be 32, and the results of `opA` /
+/// `opB` must hold 32 elements each.
+void shuffleNonUnitDimOperand(PatternRewriter &rewriter, Operation *opA,
+                              Operation *opB, vector::ContractionOp contractA,
+                              vector::ContractionOp contractB,
+                              int64_t nonUnitDimAcc, VectorType Ty) {
+  static constexpr int64_t maskLo[] = {
+      0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
+      16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
+  static constexpr int64_t maskHi[] = {
+      4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
+      20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
+
+  // The shuffles yield `nonUnitDimAcc` elements, which must match the
+  // 32-entry masks above.
+  assert(nonUnitDimAcc == 32 && "expected 32-element operands");
+  assert(opA->getBlock() == opB->getBlock() &&
+         "expected producer ops in the same block");
+  Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+  rewriter.setInsertionPointAfter(insertAfter);
+
+  std::pair<Value, Value> shuffled = shuffleFlattenedPair(
+      rewriter, insertAfter->getLoc(), opA->getResult(0), opB->getResult(0),
+      nonUnitDimAcc, maskLo, maskHi, Ty);
+
+  rewriteUses(opA->getResult(0), shuffled.first, contractA, rewriter);
+  rewriteUses(opB->getResult(0), shuffled.second, contractB, rewriter);
+}
+
+/// Returns true if `pairContOp` can be paired with `contractOp`. Both must
+/// have the same kind, indexing maps and accumulator type. They must share
+/// their LHS when `rhsHasMultipleNonUnitDims` is set and their RHS otherwise,
+/// and their other operands must be read (vector.transfer_read /
+/// vector.load) from the same buffer at indices that are, position by
+/// position, either identical or larger by exactly `nonUnitDimValue` for
+/// `pairContOp`.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+                                vector::ContractionOp pairContOp,
+                                bool rhsHasMultipleNonUnitDims,
+                                int64_t nonUnitDimValue) {
+  // Callers shuffle both accumulators with one vector type, so the two
+  // contractions must otherwise be alike.
+  if (contractOp.getKind() != pairContOp.getKind() ||
+      contractOp.getAccType() != pairContOp.getAccType() ||
+      contractOp.getIndexingMapsArray() != pairContOp.getIndexingMapsArray())
+    return false;
+
+  if (rhsHasMultipleNonUnitDims ? contractOp.getLhs() != pairContOp.getLhs()
+                                : contractOp.getRhs() != pairContOp.getRhs())
+    return false;
+
+  Value operand =
+      rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+  Value pairOperand =
+      rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
+
+  // Returns the buffer read by the vector.transfer_read / vector.load that
+  // defines `v` and stores its indices in `indices`; returns a null Value for
+  // any other producer, including when `v` is a block argument.
+  auto getReadSource = [](Value v, SmallVectorImpl<Value> &indices) -> Value {
+    Operation *defOp = v.getDefiningOp();
+    if (!defOp)
+      return Value();
+    return llvm::TypeSwitch<Operation *, Value>(defOp)
+        .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
+          indices.assign(readOp.getIndices().begin(),
+                         readOp.getIndices().end());
+          return readOp->getOperand(0);
+        })
+        .Default([](Operation *) { return Value(); });
+  };
+
+  SmallVector<Value> indices, pairIndices;
+  Value srcBuff = getReadSource(operand, indices);
+  Value pairSrcBuff = getReadSource(pairOperand, pairIndices);
+  if (!srcBuff || srcBuff != pairSrcBuff ||
+      indices.size() != pairIndices.size())
+    return false;
+
+  for (auto [index, pairIndex] : llvm::zip_equal(indices, pairIndices)) {
+    // The same SSA value (constant or not) addresses the same position.
+    if (index == pairIndex)
+      continue;
+    std::optional<int64_t> cst = getConstantIntValue(index);
+    std::optional<int64_t> pairCst = getConstantIntValue(pairIndex);
+    // Distinct non-constant indices cannot be proven equal or offset.
+    if (!cst || !pairCst)
+      return false;
+    if (*cst != *pairCst && *pairCst - *cst != nonUnitDimValue)
+      return false;
+  }
+  return true;
+}
+
 } // namespace x86vector
 } // namespace mlir
diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
new file mode 100644
index 0000000000000..7b1a6eaa1258a
--- /dev/null
+++ b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1xbf16>
+!memrefB = memref<1x1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0,  d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0,  d1, d2, d3) -> (d1, d2)>
+func.func @shuffle_VC_output_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c0, %c8] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c8] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+
+  vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @shuffle_VC_output_flat_layout
+// CHECK: vector.shuffle
+// CHECK-NEXT: vector.shuffle
+// CHECK: vector.contract
+// CHECK: vector.shuffle
+// CHECK-NEXT: vector.shuffle
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 7fcfa7772f5e15ebb12653a10ba0c9d061d4bd84 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 6 Jan 2026 06:37:29 -0800
Subject: [PATCH 02/17] fix header format issue

---
 .../X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp   | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
index 2081de9f2be91..faa1b0cd48131 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -1,5 +1,4 @@
-//===-
-//ShuffleBF16VectorContractResult.cpp-----------------------------------------===//
+//===- ShuffleBF16VectorContractResult.cpp --------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From b73ef356f51547a07efb7b3149062e01d3f675a2 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 6 Jan 2026 20:33:32 -0800
Subject: [PATCH 03/17] added couple of unit tests

---
 .../shuffle-bf16-vector-contract-result.mlir  | 143 +++++++++++++++++-
 1 file changed, 138 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
index 7b1a6eaa1258a..c109543a8a163 100644
--- a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
+++ b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
 
-
 !vecA = vector<1x1x1xbf16>
 !vecB = vector<1x1x8xbf16>
 !vecC = vector<1x8xf32>
@@ -49,11 +48,145 @@ func.func @shuffle_VC_output_flat_layout(
 }
 
 // CHECK-LABEL: @shuffle_VC_output_flat_layout
-// CHECK: vector.shuffle
-// CHECK-NEXT: vector.shuffle
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: vector.contract
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @shuffle_VC_output_flat_layout_transfer_read(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @shuffle_VC_output_flat_layout_transfer_read
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: vector.contract
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @shuffle_VC_output_flat_layout_bf16dp(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @shuffle_VC_output_flat_layout_bf16dp
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
 // CHECK: vector.contract
-// CHECK: vector.shuffle
-// CHECK-NEXT: vector.shuffle
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

>From cc40fb30a3839b89cd55c56e11cff509d465df28 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 8 Jan 2026 07:25:29 -0800
Subject: [PATCH 04/17] added support for BF16 flat layout to vector.fma using
 packed operations

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 173 ++++--
 .../X86Vector/Utils/X86VectorUtils.cpp        |  46 +-
 .../vector-contract-bf16-to-fma.mlir          | 495 ++++++++++++++++++
 3 files changed, 662 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index c60d9b91c18e5..eada03977595d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -29,7 +30,7 @@ using namespace mlir::x86vector;
 // Verifies that the LHS and RHS operands of a vector.contract are load or
 // vector.transfer_read operations on a memref source buffer, and checks
 // their bounds, dimensions, offsets, and strides.
-static bool validateVectorContractOperands(Value prodOp) {
+static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
   Operation *defOp = prodOp.getDefiningOp();
   if (!defOp)
     return false;
@@ -62,11 +63,13 @@ static bool validateVectorContractOperands(Value prodOp) {
   // Return false if the two innermost strides of the memref are not contiguous.
   // The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
   // an eight-element tuple of bf16 values to be contiguous.
-  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(2))
+  int dimsToCheck = isVnni ? 2 : 1;
+  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(
+          dimsToCheck))
     return false;
 
   // Return false if the vnni offset of load or transfer_read is not zero.
-  if (getConstantIntValue(indexVals.back()) != 0)
+  if (isVnni && getConstantIntValue(indexVals.back()) != 0)
     return false;
 
   return true;
@@ -96,7 +99,8 @@ static bool validateVectorContractOperands(Value prodOp) {
 // ```
 static SmallVector<memref::SubViewOp>
 getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
-                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
+                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim,
+                          bool isVNNI) {
 
   Operation *defOp = prodOp.getDefiningOp();
 
@@ -122,11 +126,26 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
     }
   }
 
-  int vnniDimSize = isUnitDim ? 1 : 2;
+  auto one = rewriter.getIndexAttr(1);
+  llvm::SmallVector<memref::SubViewOp> subviews;
 
+  if (!isVNNI) {
+    SmallVector<OpFoldResult> strides(indexVals.size(), one);
+    SmallVector<OpFoldResult> sizes(indexVals.size(), one);
+    // Retrieve twice the non-unit dim BF16 elements so that both the even-
+    // and odd-indexed elements are covered by a single subview.
+    if (!isUnitDim)
+      mnDimSize = 2 * mnDimSize;
+    sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+    auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+                                             sizes, strides);
+    subviews.push_back(subview);
+    return subviews;
+  }
+
+  int vnniDimSize = isUnitDim ? 1 : 2;
   auto nonVNNIDimSize = indexVals.size() - 1;
   // Create the size and stride offsets.
-  auto one = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult> strides(indexVals.size(), one);
   SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
 
@@ -139,7 +158,6 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   if (isUnitDim)
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
 
-  llvm::SmallVector<memref::SubViewOp> subviews;
   auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
                                            sizes, strides);
   subviews.push_back(subview);
@@ -168,7 +186,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
 // Implements outer product contraction as a sequence of BF16-packed
 // operation even/odd loads and FMA operations.
 //
-// For example:
+// For example (VNNI packed):
 // ```
 //   %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
 //   %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
@@ -183,6 +201,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
 //   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
 //   return vector.fma %4, %5, %3
 // ```
+//
+// For example (Flat layout):
+// ```
+//   %1 = vector.load from memref (%m1) -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m2) -> vector<1x8xbf16>
+//   %3 = vector.contract %1, %2, %arg1
+//   %4 = vector.load from memref (%m2) -> vector<1x8xbf16>
+//   %5 = vector.contract %1, %4, %arg2
+//   scf.yield %3, %5
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+//   %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+//   %5 = vector.fma %1, %4, %arg2
+//   scf.yield %3, %5
 struct VectorContractBF16ToFMA
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -202,11 +238,9 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(contractOp,
                                          "Only BF16 lowering is supported.");
 
-    if (!isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(),
-                        /*blockingFactor=*/2))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices not in VNNI format.");
+    bool isVnni = isInVnniLayout(contractOp.getOperation(),
+                                 contractOp.getIndexingMapsArray(),
+                                 /*blockingFactor=*/2);
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
@@ -216,6 +250,14 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(
           contractOp, "Only F32 acumulation supported for BF16 type.");
 
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimLhs;
     llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
@@ -227,35 +269,38 @@ struct VectorContractBF16ToFMA
     llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
                   [](int64_t dim) { return dim != 1; });
 
-    if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
+    if (isVnni && (nonUnitDimLhs.size() - 1) > 0 &&
+        (nonUnitDimRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(contractOp,
                                          "Excepts unit dimensions for either "
                                          "LHS or RHS shape other than VNNI.");
 
-    if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+    if (isVnni && (nonUnitDimLhs.size() - 1) != 1 &&
+        (nonUnitDimRhs.size() - 1) != 1)
       return rewriter.notifyMatchFailure(
           contractOp,
           "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
-    ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimAcc;
-    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
-                  [](int64_t dim) { return dim != 1; });
-    if (nonUnitDimAcc.size() != 1)
+    if (!isVnni && nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Excepts unit dimensions for either "
+                                         "LHS or RHS shape.");
+
+    if (!isVnni && nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
       return rewriter.notifyMatchFailure(
-          contractOp, "A or B should be a non-unit dim in acc.");
+          contractOp,
+          "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
     // Non-unit dimensions should match the vector length of BF16.
-    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
-                                                        : nonUnitDimRhs.front();
-    if (nonUnitDim != 4 && nonUnitDim != 8 &&
-        !(nonUnitDimAcc.front() == nonUnitDim))
+    unsigned int nonUnitDim = nonUnitDimAcc.front();
+
+    if (nonUnitDim != 4 && nonUnitDim != 8)
       return rewriter.notifyMatchFailure(
           contractOp, "BF16 packed load operation expects non-unit (LHR or "
                       "RHS) dim and acc dim of size 4/8.");
 
-    if (!validateVectorContractOperands(contractOp.getLhs()) ||
-        !validateVectorContractOperands(contractOp.getRhs())) {
+    if (!validateVectorContractOperands(contractOp.getLhs(), isVnni) ||
+        !validateVectorContractOperands(contractOp.getRhs(), isVnni)) {
       return rewriter.notifyMatchFailure(
           contractOp, "The LHS or RHS is in an invalid format. Either it has "
                       "false in-bounds, "
@@ -273,7 +318,12 @@ struct VectorContractBF16ToFMA
     // vector<1x8x2xbf16>, we create two subview for the LHS and one subview
     // for the RHS. In the opposite case (non-unit dimension on the LHS), we
     // do vice-versa.
-    bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+
+    bool rhsHasMultipleNonUnitDims = nonUnitDimRhs.size() > 0;
+    if (isVnni) {
+      rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+    }
+
     // Select which operand is "unit" and which is "non-unit".
     Value unitSrc =
         rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
@@ -283,12 +333,38 @@ struct VectorContractBF16ToFMA
     ArrayRef<int64_t> nonUnitDimShape =
         rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
 
+    // Find the paired vector.contract operation. A candidate is paired when:
+    //  (1) it shares the same unit-dim operand (LHS or RHS),
+    //  (2) its non-unit-dim operand is read from the same source memref,
+    //  (3) its read offset exceeds this op's by the non-unit acc dim size.
+    // Only operations after `contractOp` in the same block are searched.
+    vector::ContractionOp pairContractOp;
+    if (!isVnni) {
+      Operation *nextOp = contractOp;
+      while ((nextOp = nextOp->getNextNode())) {
+        auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
+
+        if (!contOp)
+          continue;
+
+        if (validatePairVectorContract(contractOp, contOp,
+                                       rhsHasMultipleNonUnitDims,
+                                       nonUnitDimAcc.front())) {
+          pairContractOp = contOp;
+          break;
+        }
+      }
+
+      if (!pairContractOp)
+        return failure();
+    }
+
     // Build subviews.
-    auto unitDimSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
-                                                    nonUnitDimShape, true);
+    auto unitDimSubview = getSubviewFromVectorInput(
+        loc, rewriter, unitSrc, nonUnitDimShape, true, isVnni);
 
     auto nonUnitDimSubview = getSubviewFromVectorInput(
-        loc, rewriter, nonUnitSrc, nonUnitDimShape, false);
+        loc, rewriter, nonUnitSrc, nonUnitDimShape, false, isVnni);
 
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc,
@@ -297,6 +373,41 @@ struct VectorContractBF16ToFMA
     VectorType dstType =
         VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
 
+    if (!isVnni) {
+      auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
+          rewriter, loc, dstType, unitDimSubview[0]);
+      auto loadEvenIdxElementF32 =
+          x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                         nonUnitDimSubview[0]);
+      auto evenIdxFMA =
+          vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
+                                loadEvenIdxElementF32, castAcc);
+      auto castEvenFma =
+          vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
+      rewriter.replaceOp(contractOp, castEvenFma);
+
+      rewriter.setInsertionPoint(pairContractOp);
+      auto pairContOpLoc = pairContractOp.getLoc();
+      VectorType accTyPairCont =
+          dyn_cast<VectorType>(pairContractOp.getAccType());
+      auto castAccPairCont = vector::ShapeCastOp::create(
+          rewriter, pairContOpLoc,
+          VectorType::get(nonUnitDimAcc.front(),
+                          accTyPairCont.getElementType()),
+          pairContractOp.getAcc());
+
+      auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+          rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
+      auto oddIdxFMA = vector::FMAOp::create(
+          rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
+          loadOddIdxElementF32, castAccPairCont);
+      auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
+                                                    accTyPairCont, oddIdxFMA);
+      rewriter.replaceOp(pairContractOp, castOddFma);
+
+      return success();
+    }
+
     // Load, broadcast, and do FMA for odd indexed BF16 elements.
     auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[0]);
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 65cd0773160ef..38edb8091ce72 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -369,15 +369,16 @@ void shuffleNonUnitDimOperand(mlir::PatternRewriter &rewriter,
   rewriteUses(opB->getResult(0), newAccB.getResult(), contractB, rewriter);
 }
 
+// Returns true if `pairContOp` can be paired with `contractOp`, i.e.:
+//  (1) both share the same unit-dim operand (LHS or RHS),
+//  (2) their non-unit-dim operands are read (load/transfer_read) from the
+//      same source memref,
+//  (3) all read indices are constants, and each index of the pair's read
+//      either equals the original's or exceeds it by `nonUnitDimValue`.
 bool validatePairVectorContract(vector::ContractionOp contractOp,
                                 vector::ContractionOp pairContOp,
                                 bool rhsHasMultipleNonUnitDims,
                                 int64_t nonUnitDimValue) {
-
-  if (!(contractOp.getLhs() == pairContOp.getLhs()) &&
-      !(contractOp.getRhs() == pairContOp.getRhs()))
-    return false;
-
   if (rhsHasMultipleNonUnitDims &&
       !(contractOp.getLhs() == pairContOp.getLhs()))
     return false;
@@ -386,43 +387,46 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
       !(contractOp.getRhs() == pairContOp.getRhs()))
     return false;
 
-  auto op =
+  auto nonUnitOperand =
       rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
-  auto op1 =
+  auto nonUnitOperandPairContOp =
       rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
 
   Value srcBuff;
   SmallVector<OpFoldResult> indexVals;
-  llvm::TypeSwitch<mlir::Operation *>(op.getDefiningOp())
+  llvm::TypeSwitch<mlir::Operation *>(nonUnitOperand.getDefiningOp())
       .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
         srcBuff = readOp.getOperand(0);
         indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                               readOp.getIndices().end());
       });
 
-  Value srcBuff1;
-  SmallVector<OpFoldResult> indexVals1;
-  llvm::TypeSwitch<mlir::Operation *>(op1.getDefiningOp())
+  Value srcBuffPairContOp;
+  SmallVector<OpFoldResult> indexValsPairContOp;
+  llvm::TypeSwitch<mlir::Operation *>(nonUnitOperandPairContOp.getDefiningOp())
       .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
-        srcBuff1 = readOp.getOperand(0);
-        indexVals1 = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
-                                               readOp.getIndices().end());
+        srcBuffPairContOp = readOp.getOperand(0);
+        indexValsPairContOp = SmallVector<OpFoldResult>(
+            readOp.getIndices().begin(), readOp.getIndices().end());
       });
 
-  if (!srcBuff || !srcBuff1)
+  if (!srcBuff || !srcBuffPairContOp)
     return false;
 
-  if (!(srcBuff == srcBuff1))
+  if (!(srcBuff == srcBuffPairContOp))
     return false;
 
   for (size_t i = 0; i < indexVals.size(); i++) {
-    if (getConstantIntValue(indexVals[i]) == getConstantIntValue(indexVals1[i]))
-      continue;
+    auto v0 = getConstantIntValue(indexVals[i]);
+    auto v1 = getConstantIntValue(indexValsPairContOp[i]);
+
+    if (!v0 || !v1)
+      return false;
 
-    auto value1 = *getConstantIntValue(indexVals1[i]);
-    auto value2 = *getConstantIntValue(indexVals[i]);
+    if (*v1 == *v0)
+      continue;
 
-    if ((value1 - value2) != nonUnitDimValue)
+    if ((*v1 - *v0) != nonUnitDimValue)
       return false;
   }
 
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index e7a70429490d1..05404fa1ef6a5 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -233,6 +233,228 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout_load(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c8] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c8] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_load
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<8x1xbf16>
+!vecB = vector<1x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1xbf16>
+!memrefB = memref<1x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_to_fma_flat_layout_bcstB(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+
+  %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %2, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_bcstB
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
+// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
+// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x2xbf16>
 !vecB = vector<1x8x2xbf16>
 !vecC = vector<1x8xf32>
@@ -413,6 +635,279 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Negative: the two B-operand loads (and the accumulator loads/stores) are
+// 16 elements apart rather than 8, so vector_contract_bf16_to_fma is
+// expected to leave both contracts un-lowered (no FMA ops).
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_offset_diff_is_not_8(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c16] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_offset_diff_is_not_8
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the contract reading the %c8 operands (%7) appears before the one
+// reading the %c0 operands (%6), so the contracts are expected to stay
+// un-lowered (no FMA ops).
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1xbf16>
+!memrefB = memref<1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_vector_contracts_not_in_order(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_vector_contracts_not_in_order
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the second A-operand read uses a dynamic row index (%arg3)
+// instead of a constant, so the contracts are expected to stay un-lowered
+// (no FMA ops).
+!vecA = vector<8x1xbf16>
+!vecB = vector<1x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1xbf16>
+!memrefB = memref<1x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_flat_layout_dynamic_index(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg0[%arg3, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+
+  %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %2, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_layout_dynamic_index
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the reduction (K) dimension has size 2 (A is 8x2, B is 2x1), so
+// the contracts are expected to stay un-lowered (no FMA ops).
+!vecA = vector<8x2xbf16>
+!vecB = vector<2x1xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x2xbf16>
+!memrefB = memref<2x4xbf16>
+!memrefC = memref<32x4xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_non_unit_K_dim(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+
+  %3 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c8, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %2, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c8, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_non_unit_K_dim
+// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
+// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
+// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x2xbf16>
 !vecB = vector<1x1x16x2xbf16>
 !vecC = vector<1x1x16xf32>

>From 7292900611a129531775e29af950cfa1c25d6a02 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 9 Jan 2026 00:45:23 -0800
Subject: [PATCH 05/17] Round:1 comments + formatting and added corner
 test-cases

---
 .../ShuffleBF16VectorContractResult.cpp       |  79 +++--
 .../shuffle-bf16-vector-contract-result.mlir  | 274 +++++++++++++++++-
 2 files changed, 332 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
index faa1b0cd48131..24b8a7489dbfa 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -26,6 +26,32 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+// Shuffles the accumulators and the results of a pair of BF16 flat-layout
+// vector.contract operations.
+//
+// For example:
+// ```
+//   %1 = vector.load -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %4 = vector.contract %1, %2, %arg0 -> vector<1x8xf32>
+//   %5 = vector.contract %1, %3, %arg1 -> vector<1x8xf32>
+//   vector.store %4, %m2
+//   vector.store %5, %m2
+// ```
+// is rewritten to
+// ```
+//   %1 = vector.load -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %4 = vector.shuffle %arg0, %arg1 [0, 8, 1, 9, 2, 10, 3, 11]
+//   %5 = vector.shuffle %arg0, %arg1 [4, 12, 5, 13, 6, 14, 7, 15]
+//   %6 = vector.contract %1, %2, %4 -> vector<1x8xf32>
+//   %7 = vector.contract %1, %3, %5 -> vector<1x8xf32>
+//   %8 = vector.shuffle %6, %7 [0, 8, 1, 9, 2, 10, 3, 11]
+//   %9 = vector.shuffle %6, %7 [4, 12, 5, 13, 6, 14, 7, 15]
+//   vector.store %8, %m2
+//   vector.store %9, %m2
+// ```
 struct ShuffleBF16VectorContractResult
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -85,22 +111,15 @@ struct ShuffleBF16VectorContractResult
     llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
                   [](int64_t dim) { return dim != 1; });
 
-    if ((nonUnitDimValue == 16) && (nonUnitDimLhs.size() - 1) > 0 &&
-        (nonUnitDimRhs.size() - 1) > 0)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Excepts unit dimensions for either "
-                                         "LHS or RHS shape.");
-    if (nonUnitDimValue == 8 && nonUnitDimLhs.size() > 0 &&
-        nonUnitDimRhs.size() > 0)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Excepts unit dimensions for either "
-                                         "LHS or RHS shape.");
-
     vector::ContractionOp pairContractOp;
-    bool rhsHasMultipleNonUnitDims = nonUnitDimValue == 16
-                                         ? (nonUnitDimRhs.size() - 1) > 0
-                                         : nonUnitDimRhs.size() > 0;
-
+    bool rhsHasMultipleNonUnitDims =
+        nonUnitDimRhs.size() > nonUnitDimLhs.size();
+
+    // Find the paired vector.contract operation. The pair is chosen such that:
+    //  (1) the unit-dim operand (LHS or RHS) is the same,
+    //  (2) the non-unit-dim operands are read from the same source memref,
+    //  (3) the non-unit-dim offsets of the two vector.contracts differ by 8.
     Operation *nextOp = contractOp;
     while ((nextOp = nextOp->getNextNode())) {
       auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
@@ -116,26 +135,50 @@ struct ShuffleBF16VectorContractResult
     }
 
     if (!pairContractOp)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          contractOp, "Coudn't find pair contract operation for shuffling");
 
+    // Trace back to the load or transfer_read operations of the contract
+    // accumulators.
     Operation *accReadOp0 =
         traceToVectorReadLikeParentOperation(contractOp.getAcc());
     Operation *accReadOp1 =
         traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
 
+    // Walk down the users of each contract result until a vector.store or
+    // vector.transfer_write is reached.
     Operation *resultWriteOp0 =
         traceToVectorWriteLikeUserOperation(contractOp.getResult());
     Operation *resultWriteOp1 =
         traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
 
     if (!accReadOp0 || !accReadOp1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Operands doesn't have load or transfer_read as it's parent op");
 
     if (!resultWriteOp0 || !resultWriteOp1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          contractOp, "The use of contract operations are neither vector.store "
+                      "or transfer_write");
+
+    if (contractOp->getBlock() == accReadOp1->getBlock() &&
+        contractOp->isBeforeInBlock(accReadOp1))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The load/read operation of pair contract operation is "
+                      "after the contractOp");
 
+    if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+        resultWriteOp0->isBeforeInBlock(pairContractOp))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The store/write operation of contract operation is "
+                      "before the pair contract operation");
+
+    // Shuffle the accumulators of the contract operations.
     shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
                            pairContractOp, nonUnitDimValue, accTy);
+
+    // Shuffle the outputs of the contract operations before their uses.
     shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
                              nonUnitDimValue, accTy);
 
diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
index c109543a8a163..032d22154b9fb 100644
--- a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
+++ b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
@@ -41,7 +41,6 @@ func.func @shuffle_VC_output_flat_layout(
     : !vecA, !vecB into !vecC
 
   vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-
   vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
 
   return %arg2 : !memrefC
@@ -108,7 +107,6 @@ func.func @shuffle_VC_output_flat_layout_transfer_read(
     : !vecA, !vecB into !vecC
 
   vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
-
   vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
 
   return %arg2 : !memrefC
@@ -175,7 +173,6 @@ func.func @shuffle_VC_output_flat_layout_bf16dp(
     : !vecA, !vecB into !vecC
 
   vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
-
   vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
 
   return %arg2 : !memrefC
@@ -197,3 +194,274 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Negative: the non-unit dimension is 32 (vector<1x32xf32> accumulator), so
+// the result-shuffle pattern is expected not to insert any shuffles.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x32xbf16>
+!vecC = vector<1x32xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_dim_is_32(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c32] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c32] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_dim_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the two B-operand loads are 32 elements apart (%c0 vs %c32)
+// while each loads only 16 elements, so no shuffles are expected.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_offset_diff_is_32(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_offset_diff_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the second B-operand load uses a dynamic column offset (%arg3),
+// so no shuffles are expected.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_dynamic_offset(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %arg3] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_dynamic_offset
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative: the contract using the offset-32 B operand (%7) appears before
+// the one using offset 0 (%6), so no shuffles are expected.
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_contracts_not_in_order(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_contracts_not_in_order
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 1d7a3f54b81a7ada1da33fa72b8efd78f946c87d Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 12 Jan 2026 00:35:50 -0800
Subject: [PATCH 06/17] Added support for bf16dp flat layout + unit tests.

---
 .../VectorContractToPackedTypeDotProduct.cpp  | 301 +++++++++----
 ...or-contract-to-packed-type-dotproduct.mlir | 412 +++++++++++++++++-
 2 files changed, 603 insertions(+), 110 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a00a3e5bdd766..8ec8688cf3c19 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -27,10 +28,83 @@ using namespace mlir::x86vector;
 
 namespace {
 
+// Returns true if the A or B matrix vector has already been packed
+// (shuffled) into VNNI layout. That is the case when it is produced by a
+// vector.shuffle, either directly or through a single vector.shape_cast.
+static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
+  Operation *defOp = nonUnitDimOperand.getDefiningOp();
+  // Block arguments have no defining op and are never considered shuffled.
+  if (!defOp)
+    return false;
+
+  if (isa<vector::ShuffleOp>(defOp))
+    return true;
+
+  // Look through one shape_cast. Its source may itself be a block argument,
+  // in which case getDefiningOp() returns null; plain isa<> asserts on null,
+  // so use the null-tolerant variant.
+  if (isa<vector::ShapeCastOp>(defOp))
+    return llvm::isa_and_present<vector::ShuffleOp>(
+        defOp->getOperand(0).getDefiningOp());
+
+  return false;
+}
+
+// Redirects every use of `oldVal` owned by a vector.contract or an scf.for
+// operation to `newVal`. Uses by any other kind of operation are left
+// untouched.
+//
+// The replacement goes through the rewriter so that the pattern driver (and
+// any attached listener) is notified of the in-place operand updates.
+// NOTE(review): `targetContract` is currently unused; uses in *all*
+// vector.contract users are rewritten, not only those in `targetContract`.
+// Confirm this is intended.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
+                        mlir::Operation *targetContract,
+                        mlir::PatternRewriter &rewriter) {
+  rewriter.replaceUsesWithIf(oldVal, newVal, [](mlir::OpOperand &use) {
+    return mlir::isa<mlir::vector::ContractionOp, mlir::scf::ForOp>(
+        use.getOwner());
+  });
+}
+
+// Converts a pair of flat-layout A or B matrix vectors into VNNI packed layout
+// by interleaving them with two vector.shuffle ops (the vpunpack pattern).
+// It then calls rewriteUses so that the uses of the original values are
+// redirected to the packed ones.
+//
+// Preconditions (asserted):
+//  - `opA` and `opB` are in the same block (isBeforeInBlock requires it);
+//  - `nonUnitDimAcc` is the flattened element count of `Ty` and equals 32,
+//    because the interleave masks below index two 32-element vectors.
+static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
+                                        mlir::Operation *opA,
+                                        mlir::Operation *opB,
+                                        mlir::vector::ContractionOp contractA,
+                                        mlir::vector::ContractionOp contractB,
+                                        int64_t nonUnitDimAcc,
+                                        mlir::VectorType Ty) {
+  assert(opA->getBlock() == opB->getBlock() &&
+         "expected both producers in the same block");
+  assert(nonUnitDimAcc == 32 && Ty.getNumElements() == nonUnitDimAcc &&
+         "VNNI interleave masks expect two 32-element vectors");
+
+  // Insert after the later of the two producers so both values are available.
+  mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+  rewriter.setInsertionPointAfter(insertAfter);
+  mlir::Location loc = insertAfter->getLoc();
+
+  auto elemTy = Ty.getElementType();
+  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+  // Flatten both operands so they can be interleaved by vector.shuffle.
+  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opA->getResult(0));
+  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opB->getResult(0));
+
+  // Interleave element i of A with element i of B (indices >= 32 select from
+  // B). maskLo takes groups 0-3, 8-11, 16-19 and 24-27; maskHi takes groups
+  // 4-7, 12-15, 20-23 and 28-31.
+  static constexpr int64_t maskLo[] = {
+      0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
+      16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
+  static constexpr int64_t maskHi[] = {
+      4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
+      20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
+
+  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, maskLo);
+  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, maskHi);
+
+  // Restore the original (unflattened) shape for the consumers.
+  auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
+  auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
+
+  rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
+  rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
+}
+
 // Implements packed type outer product contraction as a sequence
 // of broadcast and packed dot-product operations.
 //
-// For example - for F32 type:
+// For example - for bf16 type (VNNI):
 // ```
 //   vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
 // ```
@@ -39,6 +113,25 @@ namespace {
 //   vector.broadcast %lhs to <32xbf16>
 //   x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
 // ```
+//
+// For example - for bf16 type (Flat layout):
+// ```
+//   %1 = vector.load -> <2x16xbf16>
+//   %2 = vector.load -> <2x16xbf16>
+//   vector.contract <1x2xbf16>, %1 into <1x16xf32>
+//   vector.contract <1x2xbf16>, %2 into <1x16xf32>
+// ```
+// to
+// ```
+//   %1 = vector.load -> <2x16xbf16>
+//   %2 = vector.load -> <2x16xbf16>
+//   %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
+//   %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
+//   vector.broadcast %lhs to <32xbf16>
+//   x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
+//   vector.broadcast %lhs to <32xbf16>
+//   x86vector.avx512.dot vector<32xbf16>, %4 -> vector<16xf32>
+// ```
 struct VectorContractToPackedTypeDotProduct
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -57,10 +150,38 @@ struct VectorContractToPackedTypeDotProduct
           contractOp, "Only BF16/Int8 lowering is supported.");
 
     unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
-    if (!isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(), blockingFactor))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices not in VNNI format.");
+    bool isVnni = isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), blockingFactor);
+
+    if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
+      return failure();
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+                  [](int64_t dim) { return dim != 1; });
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    int64_t nonUnitDimValue = nonUnitDimAcc.front();
+    // Non-unit dimensions should match the vector length of BF16 or Int8
+    // dot-product.
+    if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 && nonUnitDimValue != 8 &&
+        nonUnitDimValue != 16)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 dot-product operation expects non-unit (LHR or "
+                      "RHS) dim and acc dim of size 4/8/16.");
+
+    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
+        nonUnitDimValue != 8)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Int8 dot-product operation expects non-unit (LHR or "
+                      "RHS) dim and acc dim of size 4/8.");
 
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> nonUnitDimLhs;
@@ -76,16 +197,20 @@ struct VectorContractToPackedTypeDotProduct
     if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(contractOp,
                                          "Excepts unit dimensions for either "
-                                         "LHS or RHS shape other than VNNI.");
+                                         "LHS or RHS shape.");
 
     if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
       return rewriter.notifyMatchFailure(
           contractOp,
           "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
-    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    if (!accTy)
-      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+    bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+    int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front() : nonUnitDimRhs.size();
+
+    if (!isVnni && (extraFlatDim != blockingFactor))
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "The K or reduction dim for flat layout should be 2.");
 
     if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
         (lhsTy.getElementType().isSignlessInteger(8) &&
@@ -94,109 +219,105 @@ struct VectorContractToPackedTypeDotProduct
                                          "Only F32 for BF16 or Int32 for Int8 "
                                          "accumulation type is supported.");
 
-    ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimAcc;
-    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
-                  [](int64_t dim) { return dim != 1; });
-    if (nonUnitDimAcc.size() != 1)
-      return rewriter.notifyMatchFailure(
-          contractOp, "A or B should be a non-unit dim in acc.");
+    Value unitDimOperand =
+        rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+    Value nonUnitDimOperand =
+        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
 
-    // Non-unit dimensions should match the vector length of BF16 or Int8
-    // dot-product.
-    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
-                                                        : nonUnitDimRhs.front();
-    if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
-        nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
-      return rewriter.notifyMatchFailure(
-          contractOp, "BF16 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8/16.");
+    // If the A or B matrix vector of the contract operation is not packed,
+    // then find its pair contract operation and shuffle both into VNNI layout.
+    if (!isVnni && !isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+      vector::ContractionOp pairContractOp;
+      Operation *nextOp = contractOp;
+      while ((nextOp = nextOp->getNextNode())) {
+        auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
 
-    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
-        nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Int8 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8.");
+        if (!contOp)
+          continue;
 
+        if (validatePairVectorContract(contractOp, contOp,
+                                       rhsHasMultipleNonUnitDims,
+                                       nonUnitDimValue)) {
+          pairContractOp = contOp;
+          break;
+        }
+      }
+
+      if (!pairContractOp)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Could not find a contract pair");
+
+      Value nonUnitDimOperandPairContract =
+            rhsHasMultipleNonUnitDims ? pairContractOp.getRhs() : pairContractOp.getLhs();
+
+      // Get the non-packed A or B matrix's vector<32xbf16> elements.
+      Operation *nonUnitDimReadOp =
+          traceToVectorReadLikeParentOperation(nonUnitDimOperand);
+      Operation *nonUnitDimReadOpPairContract =
+          traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
+
+      if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
+        return rewriter.notifyMatchFailure(
+            contractOp, "Could not find a valid contract pair");
+
+      if (contractOp->getBlock() == nonUnitDimReadOpPairContract->getBlock() &&
+        contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The load/read operation of pair contract operation is "
+                        "after the contractOp");
+
+      VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getRhsType() : contractOp.getLhsType();
+      
+      packNonUnitDimOperandToVNNI(rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract, contractOp, pairContractOp,
+                                blockingFactor * nonUnitDimValue, nonUnitDimTy);
+
+      nonUnitDimOperand =
+        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+    }
+
+    rewriter.setInsertionPoint(contractOp);
     auto loc = contractOp.getLoc();
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
 
+    VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getRhsType() : contractOp.getLhsType();
+    VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType() : contractOp.getRhsType();
+
     Value dp;
 
-    // Broadcast the unit-dimension LHS or RHS to match the vector length of the
-    // corresponding non-unit dimension on the other operand. For example,
-    // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
-    // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
-    // dimension on the LHS), we broadcast the RHS instead.
-    if ((nonUnitDimRhs.size() - 1) > 0) {
-      auto castRhs = vector::ShapeCastOp::create(
-          rewriter, loc,
-          VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
-                          rhsTy.getElementType()),
-          contractOp.getRhs());
-      auto castLhs = vector::ShapeCastOp::create(
-          rewriter, loc,
-          VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
-          contractOp.getLhs());
-      auto bitcastLhs = vector::BitCastOp::create(
-          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
-          castLhs);
-      auto broadcastLhs = vector::BroadcastOp::create(
+    auto castNonUnitDim = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
-          bitcastLhs);
-      auto bitcastLhsPkType = vector::BitCastOp::create(
-          rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
-
-      if (lhsTy.getElementType().isBF16()) {
-        dp = x86vector::DotBF16Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
-            castAcc, bitcastLhsPkType, castRhs);
-      }
+          VectorType::get(blockingFactor * nonUnitDimValue,
+                          nonUnitDimTy.getElementType()), nonUnitDimOperand);
 
-      if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
-            castAcc, bitcastLhsPkType, castRhs);
-      }
-    } else {
-      auto castLhs = vector::ShapeCastOp::create(
-          rewriter, loc,
-          VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
-                          lhsTy.getElementType()),
-          contractOp.getLhs());
-      auto castRhs = vector::ShapeCastOp::create(
+    auto castUnitDim = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
-          contractOp.getRhs());
-      auto bitcastRhs = vector::BitCastOp::create(
+          VectorType::get(blockingFactor, unitDimTy.getElementType()),
+          unitDimOperand);
+    auto bitcastUnitDim = vector::BitCastOp::create(
           rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
-          castRhs);
-      auto broadcastRhs = vector::BroadcastOp::create(
+          castUnitDim);
+    auto broadcastUnitDim = vector::BroadcastOp::create(
           rewriter, loc,
-          VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
-          bitcastRhs);
-      auto bitcastRhsPkType = vector::BitCastOp::create(
-          rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+          VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
+          bitcastUnitDim);
+    auto bitcastUnitDimPkType = vector::BitCastOp::create(
+          rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
 
-      if (lhsTy.getElementType().isBF16()) {
+    if (lhsTy.getElementType().isBF16()) {
         dp = x86vector::DotBF16Op::create(
             rewriter, loc,
-            VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
-            castAcc, castLhs, bitcastRhsPkType);
-      }
+            VectorType::get(nonUnitDimValue, rewriter.getF32Type()),
+            castAcc, bitcastUnitDimPkType, castNonUnitDim);
+    }
 
-      if (lhsTy.getElementType().isSignlessInteger(8)) {
+    if (lhsTy.getElementType().isSignlessInteger(8)) {
         dp = x86vector::DotInt8Op::create(
             rewriter, loc,
-            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
-            castAcc, castLhs, bitcastRhsPkType);
-      }
+            VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
+            castAcc, bitcastUnitDimPkType, castNonUnitDim);
     }
 
     if (!dp)
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 65676cbae772c..4026bf9dee912 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -240,13 +240,13 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-!vecA = vector<16x1x2xbf16>
-!vecB = vector<1x1x2xbf16>
-!vecC = vector<16x1xf32>
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x8x4xi8>
+!vecC = vector<1x8xi32>
 #map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
 #map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
 #map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
-func.func @matmul_outer_product_to_bf16dp_bcst_B(
+func.func @matmul_outer_product_to_int8dp(
   %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
 {
   %0 = vector.contract {
@@ -258,14 +258,84 @@ func.func @matmul_outer_product_to_bf16dp_bcst_B(
   return %0 : !vecC
 }
 
-// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B
+// CHECK-LABEL: @matmul_outer_product_to_int8dp
 // CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_bf16dp_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %4, %2
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %5, %3
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_bf16dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
 // CHECK: x86vector.avx512.dot
+// CHECK: x86vector.avx512.dot
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
@@ -274,27 +344,71 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-!vecA = vector<1x1x4xi8>
-!vecB = vector<1x8x4xi8>
-!vecC = vector<1x8xi32>
-#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
-#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
-#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
-func.func @matmul_outer_product_to_int8dp(
-  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_bf16dp_flat_layout_B_shuffled(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
 {
-  %0 = vector.contract {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c16] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.shape_cast %2 : !vecB to vector<32xbf16>
+  %7 = vector.shape_cast %3 : !vecB to vector<32xbf16>
+  %8 = vector.shuffle %6, %7 [0, 32, 1, 33, 2, 34, 3, 35,
+        8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18,
+        50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] :
+        vector<32xbf16>, vector<32xbf16>
+  %9 = vector.shuffle %6, %7 [4, 36, 5, 37, 6, 38, 7, 39,
+        12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53,
+        22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] :
+        vector<32xbf16>, vector<32xbf16>
+
+  %10 = vector.shape_cast %8 : vector<32xbf16> to !vecB
+  %11 = vector.shape_cast %9 : vector<32xbf16> to !vecB
+
+  %12 = vector.contract {
     indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>}
-    %arg0, %arg1, %arg2
+    %1, %10, %4
     : !vecA, !vecB into !vecC
-  return %0 : !vecC
+
+  %13 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %11, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %12, %arg2[%c0, %c0]  : !memrefC, !vecC
+  vector.store %13, %arg2[%c0, %c16]  : !memrefC, !vecC
+
+  return %arg2 : !memrefC
 }
 
-// CHECK-LABEL: @matmul_outer_product_to_int8dp
-// CHECK: vector.broadcast
-// CHECK: x86vector.avx.dot.i8
+// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
+// CHECK: x86vector.avx512.dot
+// CHECK: x86vector.avx512.dot
+// CHECK-NOT: vector.contract
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -679,3 +793,261 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vecA = vector<1x4xbf16>
+!vecB = vector<4x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x4xbf16>
+!memrefB = memref<4x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_flat_other_dim_is_not_2(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_other_dim_is_not_2
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_flat_offset_diff_is_not16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %c32], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_offset_diff_is_not16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_flat_dynamic_offset(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %3 = vector.transfer_read %arg1[%c0, %arg3], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_dynamic_offset
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_flat_read_after_contract(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC      
+
+  %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB      
+  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_flat_read_after_contract
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 75ba0f685a987e70178b52b380d817ae0909e70b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 12 Jan 2026 00:39:35 -0800
Subject: [PATCH 07/17] Added support for bf16dp flat layout + unit tests.

---
 .../VectorContractToPackedTypeDotProduct.cpp  | 121 ++++++++++--------
 1 file changed, 65 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 8ec8688cf3c19..80895c512b32b 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -28,23 +28,23 @@ using namespace mlir::x86vector;
 
 namespace {
 
-// Returns true if the A or B matrix vector is packed (shuffled) to 
+// Returns true if the A or B matrix vector is packed (shuffled) to
 // VNNI layout, already.
 static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
- if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
-   if (isa<vector::ShuffleOp>(defOp)) {
+  if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
+    if (isa<vector::ShuffleOp>(defOp)) {
       return true;
-   }
+    }
 
-   if (isa<vector::ShapeCastOp>(defOp)) {
+    if (isa<vector::ShapeCastOp>(defOp)) {
       Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
       if (isa<vector::ShuffleOp>(defOpShpCst)) {
-         return true;
+        return true;
       }
-   }
- }
+    }
+  }
 
- return false;
+  return false;
 }
 
 static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
@@ -63,10 +63,12 @@ static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
 // Function to convert the flat layout A or B matrix vector<32xbf16>
 // into VNNI packed layout using the vpunpack operations
 static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
-                              mlir::Operation *opA, mlir::Operation *opB,
-                              mlir::vector::ContractionOp contractA,
-                              mlir::vector::ContractionOp contractB,
-                              int64_t nonUnitDimAcc, mlir::VectorType Ty) {
+                                        mlir::Operation *opA,
+                                        mlir::Operation *opB,
+                                        mlir::vector::ContractionOp contractA,
+                                        mlir::vector::ContractionOp contractB,
+                                        int64_t nonUnitDimAcc,
+                                        mlir::VectorType Ty) {
   mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
 
   rewriter.setInsertionPointAfter(insertAfter);
@@ -92,10 +94,8 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
   auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                                    castB, maskHi);
 
-  auto newA =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
-  auto newB =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
+  auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
+  auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
 
   rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
   rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
@@ -150,8 +150,9 @@ struct VectorContractToPackedTypeDotProduct
           contractOp, "Only BF16/Int8 lowering is supported.");
 
     unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
-    bool isVnni = isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(), blockingFactor);
+    bool isVnni =
+        isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(), blockingFactor);
 
     if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
       return failure();
@@ -171,8 +172,8 @@ struct VectorContractToPackedTypeDotProduct
     int64_t nonUnitDimValue = nonUnitDimAcc.front();
     // Non-unit dimensions should match the vector length of BF16 or Int8
     // dot-product.
-    if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 && nonUnitDimValue != 8 &&
-        nonUnitDimValue != 16)
+    if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
+        nonUnitDimValue != 8 && nonUnitDimValue != 16)
       return rewriter.notifyMatchFailure(
           contractOp, "BF16 dot-product operation expects non-unit (LHR or "
                       "RHS) dim and acc dim of size 4/8/16.");
@@ -205,12 +206,12 @@ struct VectorContractToPackedTypeDotProduct
           "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
     bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
-    int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front() : nonUnitDimRhs.size();
+    int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
+                                                     : nonUnitDimRhs.size();
 
     if (!isVnni && (extraFlatDim != blockingFactor))
       return rewriter.notifyMatchFailure(
-          contractOp,
-          "The K or reduction dim for flat layout should be 2.");
+          contractOp, "The K or reduction dim for flat layout should be 2.");
 
     if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
         (lhsTy.getElementType().isSignlessInteger(8) &&
@@ -244,11 +245,12 @@ struct VectorContractToPackedTypeDotProduct
       }
 
       if (!pairContractOp)
-        return rewriter.notifyMatchFailure(
-            contractOp, "Could not find a contract pair");
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "Could not find a contract pair");
 
-      Value nonUnitDimOperandPairContract =
-            rhsHasMultipleNonUnitDims ? pairContractOp.getRhs() : pairContractOp.getLhs();
+      Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
+                                                ? pairContractOp.getRhs()
+                                                : pairContractOp.getLhs();
 
       // Get the non-packed A or B matrix's vector<32xbf16> elements.
       Operation *nonUnitDimReadOp =
@@ -261,18 +263,21 @@ struct VectorContractToPackedTypeDotProduct
             contractOp, "Could not find a valid contract pair");
 
       if (contractOp->getBlock() == nonUnitDimReadOpPairContract->getBlock() &&
-        contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+          contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
         return rewriter.notifyMatchFailure(
             contractOp, "The load/read operation of pair contract operation is "
                         "after the contractOp");
 
-      VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getRhsType() : contractOp.getLhsType();
-      
-      packNonUnitDimOperandToVNNI(rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract, contractOp, pairContractOp,
-                                blockingFactor * nonUnitDimValue, nonUnitDimTy);
+      VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+                                    ? contractOp.getRhsType()
+                                    : contractOp.getLhsType();
+
+      packNonUnitDimOperandToVNNI(
+          rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract, contractOp,
+          pairContractOp, blockingFactor * nonUnitDimValue, nonUnitDimTy);
 
       nonUnitDimOperand =
-        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+          rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
     }
 
     rewriter.setInsertionPoint(contractOp);
@@ -282,42 +287,46 @@ struct VectorContractToPackedTypeDotProduct
         VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
 
-    VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getRhsType() : contractOp.getLhsType();
-    VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType() : contractOp.getRhsType();
+    VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+                                  ? contractOp.getRhsType()
+                                  : contractOp.getLhsType();
+    VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
+                                                     : contractOp.getRhsType();
 
     Value dp;
 
     auto castNonUnitDim = vector::ShapeCastOp::create(
-          rewriter, loc,
-          VectorType::get(blockingFactor * nonUnitDimValue,
-                          nonUnitDimTy.getElementType()), nonUnitDimOperand);
+        rewriter, loc,
+        VectorType::get(blockingFactor * nonUnitDimValue,
+                        nonUnitDimTy.getElementType()),
+        nonUnitDimOperand);
 
     auto castUnitDim = vector::ShapeCastOp::create(
-          rewriter, loc,
-          VectorType::get(blockingFactor, unitDimTy.getElementType()),
-          unitDimOperand);
+        rewriter, loc,
+        VectorType::get(blockingFactor, unitDimTy.getElementType()),
+        unitDimOperand);
     auto bitcastUnitDim = vector::BitCastOp::create(
-          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
-          castUnitDim);
+        rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+        castUnitDim);
     auto broadcastUnitDim = vector::BroadcastOp::create(
-          rewriter, loc,
-          VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
-          bitcastUnitDim);
+        rewriter, loc,
+        VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
+        bitcastUnitDim);
     auto bitcastUnitDimPkType = vector::BitCastOp::create(
-          rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
+        rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
 
     if (lhsTy.getElementType().isBF16()) {
-        dp = x86vector::DotBF16Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimValue, rewriter.getF32Type()),
-            castAcc, bitcastUnitDimPkType, castNonUnitDim);
+      dp = x86vector::DotBF16Op::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
+          bitcastUnitDimPkType, castNonUnitDim);
     }
 
     if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
-            castAcc, bitcastUnitDimPkType, castNonUnitDim);
+      dp = x86vector::DotInt8Op::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
+          castAcc, bitcastUnitDimPkType, castNonUnitDim);
     }
 
     if (!dp)

>From d07ade88504ef8bcb903e997d80516a72a760b63 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 12 Jan 2026 07:55:12 -0800
Subject: [PATCH 08/17] Comments + formatting and added corner test-cases

---
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  | 25 +++++--
 .../X86Vector/Utils/X86VectorUtils.cpp        | 74 +++++--------------
 2 files changed, 37 insertions(+), 62 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 7267c5c2032d9..3b9a10f77d35f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -29,28 +29,37 @@ namespace x86vector {
 bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
                     std::optional<unsigned> blockingFactor = std::nullopt);
 
+// Returns true if two contraction ops form a valid pair for VNNI packing.
+// It verifies that both contractions share the appropriate operand, read from
+// the same source buffer, and use constant indices that differ by 8 or 16.
 bool validatePairVectorContract(vector::ContractionOp contractOp,
                                 vector::ContractionOp pairContOp,
                                 bool rhsHasMultipleNonUnitDims,
                                 int64_t nonUnitDimValue);
 
+// Walks backward from a value to find its originating vector read-like op
+// (vector.transfer_read or vector.load), following scf.for iter-args but
+// stopping at layout-transforming ops; returns the read op or nullptr.
 Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
 
+// Recursively traces a value to find a downstream vector write-like op
+// (vector.transfer_write or vector.store), crossing scf.for/yield but
+// stopping at layout-altering ops; returns the first match or nullptr.
 Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
 
-void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
-                              Operation *op1, int64_t nonUnitDimAcc,
-                              VectorType accTy);
-
+// Packs the accumulators of two flat BF16 vector.contraction ops into a
+// VNNI-packed layout and replaces the original accumulators to enable post-read
+// packing transformations.
 void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
                             Operation *op1, vector::ContractionOp contractOp,
                             vector::ContractionOp pairContractOp,
                             int64_t nonUnitDimAcc, VectorType accTy);
 
-void shuffleNonUnitDimOperand(PatternRewriter &rewriter, Operation *op,
-                              Operation *op1, vector::ContractionOp contractOp,
-                              vector::ContractionOp pairContractOp,
-                              int64_t nonUnitDimAcc, VectorType Ty);
+// Shuffles vectors produced by vector.contraction ops into a flat layout
+// before they are written to memory.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, int64_t nonUnitDimAcc,
+                              VectorType accTy);
 
 } // namespace x86vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 38edb8091ce72..ac7d70f07c377 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -120,10 +120,11 @@ inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
   // We only support these two layouts for now.
   assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
          "Unsupported nonUnitDimAcc value");
-
+  // Do interleaving between two <8xf32> targeting AVX2.
   static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
   static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
 
+  // Shuffle two <16xf32> as below targeting AVX512.
   static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
                                          4, 5, 6, 7, 20, 21, 22, 23};
   static constexpr int64_t maskHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
@@ -132,10 +133,16 @@ inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
   if (nonUnitDimAcc == 16)
     return {maskLo16, maskHi16};
 
-  // nonUnitDimAcc == 8
   return {maskLo8, maskHi8};
 }
 
+// This function walks backward from a value to locate its originating
+// vector read-like operation (`vector.transfer_read` or `vector.load`).
+// It follows simple forwarding through unary ops and across `scf.for`
+// loop iter-arguments, while stopping if layout-transforming ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the read-like defining operation or `nullptr` if no valid source
+// is found.
 Operation *traceToVectorReadLikeParentOperation(Value v) {
   while (true) {
     // Case 1: Value defined by an operation
@@ -180,6 +187,12 @@ Operation *traceToVectorReadLikeParentOperation(Value v) {
   }
 }
 
+// This function recursively traces a value through its uses to find
+// a downstream vector write-like operation (`vector.transfer_write`
+// or `vector.store`). It transparently follows values across `scf.for`
+// and `scf.yield` boundaries while stopping if layout-altering ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the first matching write-like user or `nullptr` if none is found.
 Operation *traceToVectorWriteLikeUserOperation(Value v) {
   for (OpOperand &use : v.getUses()) {
     Operation *user = use.getOwner();
@@ -227,8 +240,6 @@ static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
   for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
 
     mlir::Operation *user = use.getOwner();
-
-    // if (user == targetContract ||
     if (mlir::isa<mlir::vector::ContractionOp>(user) ||
         mlir::isa<mlir::scf::ForOp>(user)) {
       use.set(newVal);
@@ -236,6 +247,9 @@ static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
   }
 }
 
+// This function packs the accumulator of two flat BF16 vector.contract
+// operations into VNNI packed and are then replaced in their respective
+// contraction ops, enabling post-read layout or packing transformations.
 void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
                             mlir::Operation *opA, mlir::Operation *opB,
                             mlir::vector::ContractionOp contractA,
@@ -251,7 +265,6 @@ void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
 
   auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
                                                  opA->getResult(0));
-
   auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
                                                  opB->getResult(0));
 
@@ -259,21 +272,20 @@ void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
 
   auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                                    castB, masks.maskLo);
-
   auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                                    castB, masks.maskHi);
 
   auto newAccA =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
-
   auto newAccB =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
 
   rewriteUses(opA->getResult(0), newAccA.getResult(), contractA, rewriter);
-
   rewriteUses(opB->getResult(0), newAccB.getResult(), contractB, rewriter);
 }
 
+// This function shuffles the vectors written by vector.ocntract operation
+// as a flat layout structure before they are stored.
 void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
                               mlir::Operation *opA, mlir::Operation *opB,
                               int64_t nonUnitDimAcc, mlir::VectorType accTy) {
@@ -302,7 +314,6 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
 
   // Flatten vectors
   auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
-
   auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
 
   // TODO: derive shuffle masks instead of hard-coding
@@ -310,14 +321,12 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
 
   auto shuffledLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
                                                     castA, castB, masks.maskLo);
-
   auto shuffledHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
                                                     castA, castB, masks.maskHi);
 
   // Cast back to accumulator type
   auto newVecA =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
-
   auto newVecB =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
 
@@ -326,49 +335,6 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
   opB->setOperand(0, newVecB.getResult());
 }
 
-void shuffleNonUnitDimOperand(mlir::PatternRewriter &rewriter,
-                              mlir::Operation *opA, mlir::Operation *opB,
-                              mlir::vector::ContractionOp contractA,
-                              mlir::vector::ContractionOp contractB,
-                              int64_t nonUnitDimAcc, mlir::VectorType Ty) {
-  mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
-
-  rewriter.setInsertionPointAfter(insertAfter);
-  mlir::Location loc = insertAfter->getLoc();
-
-  auto elemTy = Ty.getElementType();
-  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
-
-  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opA->getResult(0));
-
-  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opB->getResult(0));
-
-  static constexpr int64_t maskLo[] = {
-      0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
-      16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
-  static constexpr int64_t maskHi[] = {
-      4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
-      20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
-
-  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, maskLo);
-
-  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, maskHi);
-
-  auto newAccA =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
-
-  auto newAccB =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
-
-  rewriteUses(opA->getResult(0), newAccA.getResult(), contractA, rewriter);
-
-  rewriteUses(opB->getResult(0), newAccB.getResult(), contractB, rewriter);
-}
-
 // Return true if vector.contract operations matches on below conditions:
 //  (1) - the unitDim operand Lhs or Rhs should be same,
 //  (2) - the defining source memref should be same for nonUnitDim

>From 4bd22fa7a57f1302b41e04c2be9476510f909356 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 12 Jan 2026 07:57:47 -0800
Subject: [PATCH 09/17] Comments + formatting and added corner test-cases

---
 mlir/include/mlir/Dialect/X86Vector/Transforms.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 538ffaac79998..e07fb4aedf539 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,7 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+// Shuffle the output of BF16 type flat layout vector.contract operations.
 void populateShuffleBF16VectorContractResultPatterns(
     RewritePatternSet &patterns);
 

>From a7c7e6d3efcdf52a39b58856eeabbc2f882517d3 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 18 Jan 2026 21:52:04 -0800
Subject: [PATCH 10/17] typo fix in comments.

---
 .../lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index ac7d70f07c377..eceaebbf8be27 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -234,12 +234,12 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
   return nullptr;
 }
 
-static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
-                        mlir::Operation *targetContract,
-                        mlir::PatternRewriter &rewriter) {
+// TODO: replace all use with the packed value along with contration
+// and for op.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal) {
   for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
-
     mlir::Operation *user = use.getOwner();
+
     if (mlir::isa<mlir::vector::ContractionOp>(user) ||
         mlir::isa<mlir::scf::ForOp>(user)) {
       use.set(newVal);
@@ -280,11 +280,11 @@ void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
   auto newAccB =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
 
-  rewriteUses(opA->getResult(0), newAccA.getResult(), contractA, rewriter);
-  rewriteUses(opB->getResult(0), newAccB.getResult(), contractB, rewriter);
+  rewriteUses(opA->getResult(0), newAccA.getResult());
+  rewriteUses(opB->getResult(0), newAccB.getResult());
 }
 
-// This function shuffles the vectors written by vector.ocntract operation
+// This function shuffles the vectors written by vector.contract operation
 // as a flat layout structure before they are stored.
 void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
                               mlir::Operation *opA, mlir::Operation *opB,

>From 4051ef576a41543e600ce241833da923d2f02836 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sat, 31 Jan 2026 22:08:42 -0800
Subject: [PATCH 11/17] modified a test-case with unary operator before the
 transfer read.

---
 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp        | 3 +--
 mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index eceaebbf8be27..74585fe1bdc5b 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -155,10 +155,9 @@ Operation *traceToVectorReadLikeParentOperation(Value v) {
         return nullptr;
       }
 
-      // Continue tracing (accumulators usually forward the value)
       if (defOp->getNumOperands() == 1) {
         v = defOp->getOperand(0);
-        continue;
+        return defOp;
       }
 
       return nullptr;
diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
index 032d22154b9fb..0a02fc8e2e659 100644
--- a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
+++ b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
@@ -23,6 +23,7 @@ func.func @shuffle_VC_output_flat_layout(
         !memrefB, !vecB
   %4 = vector.load %arg2[%c0, %c0] :
         !memrefC, !vecC
+  %sqrt = math.sqrt %4 : !vecC 
   %5 = vector.load %arg2[%c0, %c8] :
         !memrefC, !vecC
 
@@ -30,7 +31,7 @@ func.func @shuffle_VC_output_flat_layout(
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["reduction", "parallel", "parallel", "reduction"],
     kind = #vector.kind<add>}
-    %1, %2, %4
+    %1, %2, %sqrt
     : !vecA, !vecB into !vecC
 
   %7 = vector.contract {

>From df91b7a8c827445f400a684bb6b6764a1ec6b9da Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 10 Feb 2026 22:00:16 -0800
Subject: [PATCH 12/17] removed shuffling of vector.contract output as a standalone pass.
 Merged with vector.contract lowering.

---
 .../TransformOps/X86VectorTransformOps.td     |  11 -
 .../mlir/Dialect/X86Vector/Transforms.h       |   4 -
 .../TransformOps/X86VectorTransformOps.cpp    |   5 -
 .../X86Vector/Transforms/CMakeLists.txt       |   1 -
 .../ShuffleBF16VectorContractResult.cpp       | 192 -------
 .../Transforms/VectorContractBF16ToFMA.cpp    |  54 +-
 .../VectorContractToPackedTypeDotProduct.cpp  | 123 +++--
 .../X86Vector/Utils/X86VectorUtils.cpp        |  15 +
 .../vector-contract-bf16-to-fma.mlir          |   6 -
 ...or-contract-to-packed-type-dotproduct.mlir | 275 +++++++++-
 .../shuffle-bf16-vector-contract-result.mlir  | 468 ------------------
 11 files changed, 429 insertions(+), 725 deletions(-)
 delete mode 100644 mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
 delete mode 100644 mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 00c611a9f3a7a..3c73eadf82167 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,17 +60,6 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyShuffleBF16VectorContractResultPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.shuffle_bf16_vector_contract_result",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Collect patterns to shuffle results of flat layout BF16 type 
-       vector.contract operations.
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index e07fb4aedf539..c25cdaf2d9428 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,10 +100,6 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
-// Shuffle the output of BF16 type flat layout vector.contract operations.
-void populateShuffleBF16VectorContractResultPatterns(
-    RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e40ddd3a4b1c0..e77d30c9c5ffb 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,11 +42,6 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }
 
-void mlir::transform::ApplyShuffleBF16VectorContractResultPatternsOp::
-    populatePatterns(RewritePatternSet &patterns) {
-  x86vector::populateShuffleBF16VectorContractResultPatterns(patterns);
-}
-
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index acbc7fcfb635e..bbd9be880eb0a 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,7 +5,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToPackedTypeDotProduct.cpp
   VectorContractBF16ToFMA.cpp
   SinkVectorProducerOps.cpp
-  ShuffleBF16VectorContractResult.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
deleted file mode 100644
index 24b8a7489dbfa..0000000000000
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
+++ /dev/null
@@ -1,192 +0,0 @@
-//===- ShuffleBF16VectorContractResult.cpp --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Dialect/X86Vector/Transforms.h"
-#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/Dominance.h"
-#include "mlir/IR/PatternMatch.h"
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/Casting.h"
-
-using namespace mlir;
-using namespace mlir::vector;
-using namespace mlir::x86vector;
-
-// Shuffle the output of BF16 type flat layout vector.contract operations
-//
-// For example:
-// ```
-//   %1 = vector.load -> vector<1x1xbf16>
-//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
-//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
-//   %4 = vector.contract %1, %2, %arg0 ->  vector<1x8xf32>
-//   %5 = vector.contract %1, %3, %arg1 ->  vector<1x8xf32>
-//   vector.store %4, %m1
-//   vector.store %5, %m1
-// ```
-// to
-// ```
-//   %1 = vector.load -> vector<1x1xbf16>
-//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
-//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
-//   %4 = vector.shuffle %arg0, %arg1 [0, 8, 1, 9, 2, 10, 3, 11]
-//   %5 = vector.shuffle %arg0, %arg1 [4, 12, 5, 13, 6, 14, 7, 15]
-//   %6 = vector.contract %1, %2, %4 ->  vector<1x8xf32>
-//   %7 = vector.contract %1, %3, %5 ->  vector<1x8xf32>
-//   %8 = vector.shuffle %6, %7 [0, 8, 1, 9, 2, 10, 3, 11]
-//   %9 = vector.shuffle %6, %7 [4, 12, 5, 13, 6, 14, 7, 15]
-//   vector.store %8, %m1
-//   vector.store %9, %m1
-//```
-struct ShuffleBF16VectorContractResult
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-
-    if (contractOp.getKind() != vector::CombiningKind::ADD)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Expects add combining kind.");
-
-    // TODO: Move this validation to a common utility folder. Planned to
-    // do once (code refactoring), all architecture specific nanokernel
-    // passes are merged into the repo.
-    VectorType lhsTy = contractOp.getLhsType();
-    if (!lhsTy.getElementType().isBF16())
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Only BF16 lowering is supported.");
-
-    if (isInVnniLayout(contractOp.getOperation(),
-                       contractOp.getIndexingMapsArray(),
-                       /*blockingFactor=*/2))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices in VNNI format.");
-
-    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    if (!accTy)
-      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
-
-    if (!accTy.getElementType().isF32())
-      return rewriter.notifyMatchFailure(
-          contractOp, "Only F32 acumulation supported for BF16 type.");
-
-    ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimAcc;
-    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
-                  [](int64_t dim) { return dim != 1; });
-
-    if (nonUnitDimAcc.size() != 1)
-      return rewriter.notifyMatchFailure(
-          contractOp, "A or B should be a non-unit dim in acc.");
-
-    int64_t nonUnitDimValue = nonUnitDimAcc.front();
-
-    if (nonUnitDimValue != 8 && nonUnitDimValue != 16)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The accumulator dimension should be 8 or 16");
-
-    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimLhs;
-    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
-                  [](int64_t dim) { return dim != 1; });
-
-    VectorType rhsTy = contractOp.getRhsType();
-    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimRhs;
-    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
-                  [](int64_t dim) { return dim != 1; });
-
-    vector::ContractionOp pairContractOp;
-    bool rhsHasMultipleNonUnitDims =
-        nonUnitDimRhs.size() > nonUnitDimLhs.size();
-
-    // Get the pair vector.contract operation. The pair is decided on:
-    //  (1) - the unitDim operand Lhs or Rhs should be same,
-    //  (2) - the defining source memref should be same for nonUnitDim
-    //  operation, (3) - the nonUnit dim offset difference between the
-    //  vector.contracts should be 8.
-    Operation *nextOp = contractOp;
-    while ((nextOp = nextOp->getNextNode())) {
-      auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
-
-      if (!contOp)
-        continue;
-
-      if (validatePairVectorContract(
-              contractOp, contOp, rhsHasMultipleNonUnitDims, nonUnitDimValue)) {
-        pairContractOp = contOp;
-        break;
-      }
-    }
-
-    if (!pairContractOp)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Coudn't find pair contract operation for shuffling");
-
-    // Trace back to the load or transfer_read operations of the contract
-    // accumulators.
-    Operation *accReadOp0 =
-        traceToVectorReadLikeParentOperation(contractOp.getAcc());
-    Operation *accReadOp1 =
-        traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
-
-    // Iterate dowm to find the users of contact operations until it is store or
-    // transfer_write.
-    Operation *resultWriteOp0 =
-        traceToVectorWriteLikeUserOperation(contractOp.getResult());
-    Operation *resultWriteOp1 =
-        traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
-
-    if (!accReadOp0 || !accReadOp1)
-      return rewriter.notifyMatchFailure(
-          contractOp,
-          "Operands doesn't have load or transfer_read as it's parent op");
-
-    if (!resultWriteOp0 || !resultWriteOp1)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The use of contract operations are neither vector.store "
-                      "or transfer_write");
-
-    if (contractOp->getBlock() == accReadOp1->getBlock() &&
-        contractOp->isBeforeInBlock(accReadOp1))
-      return rewriter.notifyMatchFailure(
-          contractOp, "The load/read operation of pair contract operation is "
-                      "after the contractOp");
-
-    if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
-        resultWriteOp0->isBeforeInBlock(pairContractOp))
-      return rewriter.notifyMatchFailure(
-          contractOp, "The store/write operation of contract operation is "
-                      "before the pair contract operation");
-
-    // Shuffle the accumulators of the contract operations.
-    shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
-                           pairContractOp, nonUnitDimValue, accTy);
-
-    // Shuffle the output of contract operations before it's use.
-    shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
-                             nonUnitDimValue, accTy);
-
-    return success();
-  }
-};
-
-void x86vector::populateShuffleBF16VectorContractResultPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ShuffleBF16VectorContractResult>(patterns.getContext());
-}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index eada03977595d..798e31f72ee97 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -293,7 +293,6 @@ struct VectorContractBF16ToFMA
 
     // Non-unit dimensions should match the vector length of BF16.
     unsigned int nonUnitDim = nonUnitDimAcc.front();
-
     if (nonUnitDim != 4 && nonUnitDim != 8)
       return rewriter.notifyMatchFailure(
           contractOp, "BF16 packed load operation expects non-unit (LHR or "
@@ -374,6 +373,55 @@ struct VectorContractBF16ToFMA
         VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
 
     if (!isVnni) {
+
+      // Locate and validate the accumulator reads and result writes.
+      Operation *accReadOp0 =
+          traceToVectorReadLikeParentOperation(contractOp.getAcc());
+      Operation *accReadOp1 =
+          traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+      // Iterate down the users of the contract results until a vector.store
+      // or transfer_write is reached.
+      Operation *resultWriteOp0 =
+          traceToVectorWriteLikeUserOperation(contractOp.getResult());
+      Operation *resultWriteOp1 =
+          traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+      // Bail out here, before the shuffles below modify the IR.
+      if (!accReadOp0 || !accReadOp1)
+        return rewriter.notifyMatchFailure(
+            contractOp,
+            "Operands doesn't have load or transfer_read as it's parent op");
+
+      if (!resultWriteOp0 || !resultWriteOp1)
+        return rewriter.notifyMatchFailure(
+            contractOp,
+            "The use of contract operations are neither vector.store "
+            "or transfer_write");
+
+      if (contractOp->getBlock() == accReadOp1->getBlock() &&
+          contractOp->isBeforeInBlock(accReadOp1))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The load/read operation of pair contract operation is "
+                        "after the contractOp");
+
+      if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+          resultWriteOp0->isBeforeInBlock(pairContractOp)) {
+        return rewriter.notifyMatchFailure(
+            contractOp, "The store/write operation of contract operation is "
+                        "before the pair contract operation");
+      }
+
+      // Shuffle the accumulators of the contract operations.
+      shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                             pairContractOp, nonUnitDim, accTy);
+
+      rewriter.setInsertionPoint(contractOp);
+      // Re-create castAcc from contractOp's (presumably shuffled) acc.
+      castAcc = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+          contractOp.getAcc());
+
       auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
           rewriter, loc, dstType, unitDimSubview[0]);
       auto loadEvenIdxElementF32 =
@@ -405,6 +453,10 @@ struct VectorContractBF16ToFMA
                                                     accTyPairCont, oddIdxFMA);
       rewriter.replaceOp(pairContractOp, castOddFma);
 
+      // Shuffle the outputs of the contract operations before their use.
+      shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+                               nonUnitDim, accTy);
+
       return success();
     }
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 80895c512b32b..e67a1cde8b3d3 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -207,7 +207,7 @@ struct VectorContractToPackedTypeDotProduct
 
     bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
     int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
-                                                     : nonUnitDimRhs.size();
+                                                     : nonUnitDimRhs.front();
 
     if (!isVnni && (extraFlatDim != blockingFactor))
       return rewriter.notifyMatchFailure(
@@ -227,7 +227,7 @@ struct VectorContractToPackedTypeDotProduct
 
     // If the A or B matrix vector of the contact operation is not packed, then
     // find it's pair contract operation and pack (shuffle) them to VNNI packed.
-    if (!isVnni && !isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+    if (!isVnni) {
       vector::ContractionOp pairContractOp;
       Operation *nextOp = contractOp;
       while ((nextOp = nextOp->getNextNode())) {
@@ -244,40 +244,99 @@ struct VectorContractToPackedTypeDotProduct
         }
       }
 
-      if (!pairContractOp)
+      // The accumulator traces back to its load/transfer_read unless it has
+      // already been shuffled, in which case the trace yields nullptr.
+      Operation *accRead =
+          traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+      if (!pairContractOp &&
+          (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
         return rewriter.notifyMatchFailure(contractOp,
                                            "Could not find a contract pair");
 
-      Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
-                                                ? pairContractOp.getRhs()
-                                                : pairContractOp.getLhs();
-
-      // Get the non-packed A or B matrix's vector<32xbf16> elements.
-      Operation *nonUnitDimReadOp =
-          traceToVectorReadLikeParentOperation(nonUnitDimOperand);
-      Operation *nonUnitDimReadOpPairContract =
-          traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
-
-      if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
-        return rewriter.notifyMatchFailure(
-            contractOp, "Could not find a valid contract pair");
-
-      if (contractOp->getBlock() == nonUnitDimReadOpPairContract->getBlock() &&
-          contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
-        return rewriter.notifyMatchFailure(
-            contractOp, "The load/read operation of pair contract operation is "
-                        "after the contractOp");
-
-      VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
-                                    ? contractOp.getRhsType()
-                                    : contractOp.getLhsType();
-
-      packNonUnitDimOperandToVNNI(
-          rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract, contractOp,
-          pairContractOp, blockingFactor * nonUnitDimValue, nonUnitDimTy);
+      // Check every accumulator-shuffle precondition before mutating the IR:
+      // a pattern must not report failure after it has changed the IR.
+      Operation *accReadOp1 = nullptr;
+      Operation *resultWriteOp0 = nullptr;
+      Operation *resultWriteOp1 = nullptr;
+      if (accRead) {
+        accReadOp1 =
+            traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+        // Follow the contract results down to their store/transfer_write.
+        resultWriteOp0 =
+            traceToVectorWriteLikeUserOperation(contractOp.getResult());
+        resultWriteOp1 =
+            traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+        if (!accReadOp1)
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "Operands doesn't have load or transfer_read as it's parent op");
+
+        if (!resultWriteOp0 || !resultWriteOp1)
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "The use of contract operations are neither vector.store "
+              "or transfer_write");
+
+        if (contractOp->getBlock() == accReadOp1->getBlock() &&
+            contractOp->isBeforeInBlock(accReadOp1))
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "The load/read operation of pair contract operation is "
+              "after the contractOp");
+
+        if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+            resultWriteOp0->isBeforeInBlock(pairContractOp))
+          return rewriter.notifyMatchFailure(
+              contractOp, "The store/write operation of contract operation is "
+                          "before the pair contract operation");
+      }
 
-      nonUnitDimOperand =
-          rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+      if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+        Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
+                                                  ? pairContractOp.getRhs()
+                                                  : pairContractOp.getLhs();
+
+        // Get the non-packed A or B matrix's vector<32xbf16> elements.
+        Operation *nonUnitDimReadOp =
+            traceToVectorReadLikeParentOperation(nonUnitDimOperand);
+        Operation *nonUnitDimReadOpPairContract =
+            traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
+
+        if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
+          return rewriter.notifyMatchFailure(
+              contractOp, "Could not find a valid contract pair");
+
+        if (contractOp->getBlock() ==
+                nonUnitDimReadOpPairContract->getBlock() &&
+            contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "The load/read operation of pair contract operation is "
+              "after the contractOp");
+
+        VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+                                      ? contractOp.getRhsType()
+                                      : contractOp.getLhsType();
+
+        packNonUnitDimOperandToVNNI(
+            rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
+            contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
+            nonUnitDimTy);
+
+        nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
+                                                      : contractOp.getLhs();
+      }
+
+      // All checks passed: shuffle the accumulators and the results.
+      if (accRead) {
+        shuffleAfterReadLikeOp(rewriter, accRead, accReadOp1, contractOp,
+                               pairContractOp, nonUnitDimValue, accTy);
+        shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+                                 nonUnitDimValue, accTy);
+      }
     }
 
     rewriter.setInsertionPoint(contractOp);
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 74585fe1bdc5b..9ce16d9ed91bc 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -364,6 +364,10 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
         srcBuff = readOp.getOperand(0);
         indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                               readOp.getIndices().end());
+      })
+      .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
+        srcBuff = op.getSource();
+        indexVals.clear();
       });
 
   Value srcBuffPairContOp;
@@ -373,11 +377,22 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
         srcBuffPairContOp = readOp.getOperand(0);
         indexValsPairContOp = SmallVector<OpFoldResult>(
             readOp.getIndices().begin(), readOp.getIndices().end());
+      })
+      .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
+        srcBuffPairContOp = op.getSource();
+        indexValsPairContOp.clear(); // shape_cast carries no indices
       });
 
   if (!srcBuff || !srcBuffPairContOp)
     return false;
 
+  auto shuffleLw = srcBuff.getDefiningOp<vector::ShuffleOp>();
+  auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();
+
+  if (shuffleLw && shuffleHw)
+    return shuffleLw.getV1() == shuffleHw.getV1() &&
+           shuffleLw.getV2() == shuffleHw.getV2();
+
   if (!(srcBuff == srcBuffPairContOp))
     return false;
 
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index 05404fa1ef6a5..a29ffda283b3c 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -297,7 +297,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
@@ -370,7 +369,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
@@ -446,7 +444,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
@@ -759,7 +756,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
@@ -829,7 +825,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
@@ -899,7 +894,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 4026bf9dee912..fd45daa44f59c 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -335,7 +335,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
@@ -851,7 +850,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
@@ -916,7 +914,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
@@ -980,7 +977,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
@@ -1045,7 +1041,276 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_contracts_not_in_order(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_contracts_not_in_order
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x32xbf16>
+!vecC = vector<1x32xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_dim_is_32(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c32] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c32] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_dim_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_offset_diff_is_32(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c32] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_offset_diff_is_32
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1xbf16>
+!vecB = vector<1x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x64xbf16>
+!memrefC = memref<2x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_dynamic_offset(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.load %arg0[%c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %arg3] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c16] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_dynamic_offset
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.contract
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.store
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
       transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
     } : !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
deleted file mode 100644
index 0a02fc8e2e659..0000000000000
--- a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
+++ /dev/null
@@ -1,468 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
-
-!vecA = vector<1x1x1xbf16>
-!vecB = vector<1x1x8xbf16>
-!vecC = vector<1x8xf32>
-!memrefA = memref<1x4x1xbf16>
-!memrefB = memref<1x1x32xbf16>
-!memrefC = memref<2x32xf32>
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0,  d1, d2, d3) -> (d0, d3, d2)>
-#map2 = affine_map<(d0,  d1, d2, d3) -> (d1, d2)>
-func.func @shuffle_VC_output_flat_layout(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %0 = ub.poison : bf16
-  %1 = vector.load %arg0[%c0, %c0, %c0] :
-        !memrefA, !vecA
-  %2 = vector.load %arg1[%c0, %c0, %c0] :
-        !memrefB, !vecB
-  %3 = vector.load %arg1[%c0, %c0, %c8] :
-        !memrefB, !vecB
-  %4 = vector.load %arg2[%c0, %c0] :
-        !memrefC, !vecC
-  %sqrt = math.sqrt %4 : !vecC 
-  %5 = vector.load %arg2[%c0, %c8] :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %sqrt
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-  vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @shuffle_VC_output_flat_layout
-// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-// CHECK: vector.contract
-// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vecA = vector<1x1xbf16>
-!vecB = vector<1x8xbf16>
-!vecC = vector<1x8xf32>
-!memrefA = memref<4x1xbf16>
-!memrefB = memref<1x32xbf16>
-!memrefC = memref<2x32xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @shuffle_VC_output_flat_layout_transfer_read(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
-        !memrefA, !vecA
-  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
-        !memrefB, !vecB
-  %3 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
-        !memrefB, !vecB
-  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
-        !memrefC, !vecC
-  %5 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
-  vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @shuffle_VC_output_flat_layout_transfer_read
-// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-// CHECK: vector.contract
-// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vecA = vector<1x2xbf16>
-!vecB = vector<2x16xbf16>
-!vecC = vector<1x16xf32>
-!memrefA = memref<4x2xbf16>
-!memrefB = memref<2x32xbf16>
-!memrefC = memref<2x32xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @shuffle_VC_output_flat_layout_bf16dp(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
-        !memrefA, !vecA
-  %2 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
-        !memrefB, !vecB
-  %3 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
-        !memrefB, !vecB
-  %4 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
-        !memrefC, !vecC
-  %5 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
-  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @shuffle_VC_output_flat_layout_bf16dp
-// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.contract
-// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
-// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vecA = vector<1x1xbf16>
-!vecB = vector<1x32xbf16>
-!vecC = vector<1x32xf32>
-!memrefA = memref<4x2xbf16>
-!memrefB = memref<2x64xbf16>
-!memrefC = memref<2x64xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @negative_dim_is_32(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.load %arg0[%c0, %c0] :
-        !memrefA, !vecA
-  %2 = vector.load %arg1[%c0, %c0] :
-        !memrefB, !vecB
-  %3 = vector.load %arg1[%c0, %c32] :
-        !memrefB, !vecB
-  %4 = vector.load %arg2[%c0, %c0] :
-        !memrefC, !vecC
-  %5 = vector.load %arg2[%c0, %c32] :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-  vector.store %7, %arg2[%c0, %c32] : !memrefC, !vecC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @negative_dim_is_32
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.contract
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.store
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vecA = vector<1x1xbf16>
-!vecB = vector<1x16xbf16>
-!vecC = vector<1x16xf32>
-!memrefA = memref<4x2xbf16>
-!memrefB = memref<2x64xbf16>
-!memrefC = memref<2x64xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @negative_offset_diff_is_32(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c32 = arith.constant 32 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.load %arg0[%c0, %c0] :
-        !memrefA, !vecA
-  %2 = vector.load %arg1[%c0, %c0] :
-        !memrefB, !vecB
-  %3 = vector.load %arg1[%c0, %c32] :
-        !memrefB, !vecB
-  %4 = vector.load %arg2[%c0, %c0] :
-        !memrefC, !vecC
-  %5 = vector.load %arg2[%c0, %c16] :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @negative_offset_diff_is_32
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.contract
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.store
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vecA = vector<1x1xbf16>
-!vecB = vector<1x16xbf16>
-!vecC = vector<1x16xf32>
-!memrefA = memref<4x2xbf16>
-!memrefB = memref<2x64xbf16>
-!memrefC = memref<2x64xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @negative_dynamic_offset(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.load %arg0[%c0, %c0] :
-        !memrefA, !vecA
-  %2 = vector.load %arg1[%c0, %c0] :
-        !memrefB, !vecB
-  %3 = vector.load %arg1[%c0, %arg3] :
-        !memrefB, !vecB
-  %4 = vector.load %arg2[%c0, %c0] :
-        !memrefC, !vecC
-  %5 = vector.load %arg2[%c0, %c16] :
-        !memrefC, !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @negative_dynamic_offset
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.contract
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.store
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-
-// -----
-
-!vecA = vector<1x1xbf16>
-!vecB = vector<1x16xbf16>
-!vecC = vector<1x16xf32>
-!memrefA = memref<4x2xbf16>
-!memrefB = memref<2x64xbf16>
-!memrefC = memref<2x64xf32>
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
-func.func @negative_contracts_not_in_order(
-  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
-{
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c32 = arith.constant 32 : index
-  %0 = ub.poison : bf16
-  %32 = ub.poison : f32
-  %1 = vector.load %arg0[%c0, %c0] :
-        !memrefA, !vecA
-  %2 = vector.load %arg1[%c0, %c0] :
-        !memrefB, !vecB
-  %3 = vector.load %arg1[%c0, %c32] :
-        !memrefB, !vecB
-  %4 = vector.load %arg2[%c0, %c0] :
-        !memrefC, !vecC
-  %5 = vector.load %arg2[%c0, %c16] :
-        !memrefC, !vecC
-
-  %7 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %3, %5
-    : !vecA, !vecB into !vecC
-
-  %6 = vector.contract {
-    indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>}
-    %1, %2, %4
-    : !vecA, !vecB into !vecC
-
-  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
-  vector.store %7, %arg2[%c0, %c16] : !memrefC, !vecC
-
-  return %arg2 : !memrefC
-}
-
-// CHECK-LABEL: @negative_contracts_not_in_order
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.contract
-// CHECK-NOT: vector.shuffle
-// CHECK-NOT: vector.shuffle
-// CHECK: vector.store
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
-    } : !transform.any_op
-    transform.yield
-  }
-}

>From ace65fc0795e3e43c32f9aea3e3893323122f6e6 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 10 Feb 2026 22:53:17 -0800
Subject: [PATCH 13/17] remove the assert condition

---
 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 9ce16d9ed91bc..4e11aebe75a9c 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -300,8 +300,6 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
   mlir::Value vecA = getWrittenVector(opA);
   mlir::Value vecB = getWrittenVector(opB);
 
-  assert(vecA && vecB && "expected vector write-like ops");
-
   // Decide insertion point and location
   mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
 

>From b6ad4f2f4e100961582a8d825fe295b4dfb1b96e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 11 Feb 2026 00:56:02 -0800
Subject: [PATCH 14/17] wrapping the standalone API with LogicalResult

---
 .../mlir/Dialect/X86Vector/Utils/X86VectorUtils.h   |  6 +++---
 .../Transforms/VectorContractBF16ToFMA.cpp          |  9 +++++++--
 .../VectorContractToPackedTypeDotProduct.cpp        |  9 +++++++--
 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp | 13 ++++++++++---
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 3b9a10f77d35f..8d359cd7de168 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -57,9 +57,9 @@ void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
 
 // Shuffles vectors produced by vector.contraction ops into a flat layout
 // before they are written to memory.
-void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
-                              Operation *op1, int64_t nonUnitDimAcc,
-                              VectorType accTy);
+LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+                                       Operation *op1, int64_t nonUnitDimAcc,
+                                       VectorType accTy);
 
 } // namespace x86vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 798e31f72ee97..493a108d4ac40 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -454,8 +454,13 @@ struct VectorContractBF16ToFMA
       rewriter.replaceOp(pairContractOp, castOddFma);
 
       // Shuffle the output of contract operations before it's use.
-      shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
-                               nonUnitDim, accTy);
+      LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
+          rewriter, resultWriteOp0, resultWriteOp1, nonUnitDim, accTy);
+
+      if (failed(writeShuffle))
+        return rewriter.notifyMatchFailure(
+            contractOp,
+            "Write to accumulator is not by transfer_write or store");
 
       return success();
     }
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 5a13a83a8b812..ecb67af184d81 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -335,8 +335,13 @@ struct VectorContractToPackedTypeDotProduct
                                pairContractOp, nonUnitDimValue, accTy);
 
         // Shuffle the output of contract operations before it's use.
-        shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
-                                 nonUnitDimValue, accTy);
+        LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
+            rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
+
+        if (failed(writeShuffle))
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "Write to accumulator is not by transfer_write or store");
       }
     }
 
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 4e11aebe75a9c..1a8bba2030909 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -285,9 +285,11 @@ void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
 
 // This function shuffles the vectors written by vector.contract operation
 // as a flat layout structure before they are stored.
-void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
-                              mlir::Operation *opA, mlir::Operation *opB,
-                              int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+LogicalResult shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
+                                       mlir::Operation *opA,
+                                       mlir::Operation *opB,
+                                       int64_t nonUnitDimAcc,
+                                       mlir::VectorType accTy) {
   // Helper to extract vector operand from write-like ops
   auto getWrittenVector = [](mlir::Operation *op) -> mlir::Value {
     if (auto write = mlir::dyn_cast<mlir::vector::TransferWriteOp>(op))
@@ -300,6 +302,9 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
   mlir::Value vecA = getWrittenVector(opA);
   mlir::Value vecB = getWrittenVector(opB);
 
+  if (!vecA || !vecB)
+    return mlir::failure();
+
   // Decide insertion point and location
   mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
 
@@ -330,6 +335,8 @@ void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
   // Update write operands in place
   opA->setOperand(0, newVecA.getResult());
   opB->setOperand(0, newVecB.getResult());
+
+  return success();
 }
 
 // Return true if vector.contract operations matches on below conditions:

>From 526fb3400b81f7eb88ab4fe996338cb74abd4003 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 11 Feb 2026 06:18:30 -0800
Subject: [PATCH 15/17] code clean-up + extra validations

---
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  | 19 ++++---
 .../Transforms/VectorContractBF16ToFMA.cpp    |  9 ++-
 .../VectorContractToPackedTypeDotProduct.cpp  | 15 +++--
 .../X86Vector/Utils/X86VectorUtils.cpp        | 55 ++++++++++---------
 4 files changed, 55 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 8d359cd7de168..5357cd322af5e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -40,26 +40,27 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
 // Walks backward from a value to find its originating vector read-like op
 // (vector.transfer_read or vector.load), following scf.for iter-args but
 // stopping at layout-transforming ops; returns the read op or nullptr.
-Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
+Operation *traceToVectorReadLikeParentOperation(Value v);
 
 // Recursively traces a value to find a downstream vector write-like op
 // (vector.transfer_write or vector.store), crossing scf.for/yield but
 // stopping at layout-altering ops; returns the first match or nullptr.
-Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
+Operation *traceToVectorWriteLikeUserOperation(Value v);
 
 // Packs the accumulators of two flat BF16 vector.contraction ops into a
 // VNNI-packed layout and replaces the original accumulators to enable post-read
 // packing transformations.
-void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
-                            Operation *op1, vector::ContractionOp contractOp,
-                            vector::ContractionOp pairContractOp,
-                            int64_t nonUnitDimAcc, VectorType accTy);
+LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
+                                     Operation *opB,
+                                     vector::ContractionOp contractA,
+                                     vector::ContractionOp contractB,
+                                     int64_t nonUnitDimAcc, VectorType accTy);
 
 // Shuffles vectors produced by vector.contraction ops into a flat layout
 // before they are written to memory.
-LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
-                                       Operation *op1, int64_t nonUnitDimAcc,
-                                       VectorType accTy);
+LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
+                                       Operation *opA, Operation *opB,
+                                       int64_t nonUnitDimAcc, VectorType accTy);
 
 } // namespace x86vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 493a108d4ac40..7e9cd32db6fc7 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -412,8 +412,13 @@ struct VectorContractBF16ToFMA
       }
 
       // Shuffle the accumulators of the contract operations.
-      shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
-                             pairContractOp, nonUnitDim, accTy);
+      LogicalResult readShuffle =
+          shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                                 pairContractOp, nonUnitDim, accTy);
+
+      if (failed(readShuffle))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Accumulator read is not by transfer_read or load");
 
       rewriter.setInsertionPoint(contractOp);
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index ecb67af184d81..2cc8736e9cc9f 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -32,15 +32,13 @@ namespace {
 // VNNI layout, already.
 static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
   if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
-    if (isa<vector::ShuffleOp>(defOp)) {
+    if (isa<vector::ShuffleOp>(defOp))
       return true;
-    }
 
     if (isa<vector::ShapeCastOp>(defOp)) {
       Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
-      if (isa<vector::ShuffleOp>(defOpShpCst)) {
+      if (isa<vector::ShuffleOp>(defOpShpCst))
         return true;
-      }
     }
   }
 
@@ -331,8 +329,13 @@ struct VectorContractToPackedTypeDotProduct
               contractOp, "The store/write operation of contract operation is "
                           "before the pair contract operation");
         // Shuffle the accumulators of the contract operations.
-        shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
-                               pairContractOp, nonUnitDimValue, accTy);
+        LogicalResult readShuffle =
+            shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                                   pairContractOp, nonUnitDimValue, accTy);
+
+        if (failed(readShuffle))
+          return rewriter.notifyMatchFailure(
+              contractOp, "Accumulator read is not by transfer_read or load");
 
         // Shuffle the output of contract operations before it's use.
         LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 1a8bba2030909..6c76d80920ce9 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -197,13 +197,11 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
     Operation *user = use.getOwner();
 
     // --- TERMINAL OPS ---
-    if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user)) {
+    if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
       return user;
-    }
 
-    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user)) {
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
       return nullptr;
-    }
 
     // --- SCF YIELD ---
     if (auto yield = dyn_cast<scf::YieldOp>(user)) {
@@ -233,27 +231,23 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
   return nullptr;
 }
 
+// This function packs the accumulators of two flat BF16 vector.contract
+// operations into a VNNI-packed layout and substitutes them into their
+// contraction ops, enabling post-read layout or packing transformations.
 // TODO: replace all use with the packed value along with contration
 // and for op.
-static void rewriteUses(mlir::Value oldVal, mlir::Value newVal) {
-  for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
-    mlir::Operation *user = use.getOwner();
-
-    if (mlir::isa<mlir::vector::ContractionOp>(user) ||
-        mlir::isa<mlir::scf::ForOp>(user)) {
-      use.set(newVal);
-    }
+LogicalResult shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
+                                     mlir::Operation *opA, mlir::Operation *opB,
+                                     mlir::vector::ContractionOp contractA,
+                                     mlir::vector::ContractionOp contractB,
+                                     int64_t nonUnitDimAcc,
+                                     mlir::VectorType accTy) {
+
+  if (!mlir::isa<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(opA) ||
+      !mlir::isa<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(opB)) {
+    return mlir::failure();
   }
-}
 
-// This function packs the accumulator of two flat BF16 vector.contract
-// operations into VNNI packed and are then replaced in their respective
-// contraction ops, enabling post-read layout or packing transformations.
-void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
-                            mlir::Operation *opA, mlir::Operation *opB,
-                            mlir::vector::ContractionOp contractA,
-                            mlir::vector::ContractionOp contractB,
-                            int64_t nonUnitDimAcc, mlir::VectorType accTy) {
   mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
 
   rewriter.setInsertionPointAfter(insertAfter);
@@ -279,8 +273,17 @@ void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
   auto newAccB =
       mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
 
-  rewriteUses(opA->getResult(0), newAccA.getResult());
-  rewriteUses(opB->getResult(0), newAccB.getResult());
+  rewriter.replaceUsesWithIf(
+      opA->getResult(0), newAccA.getResult(), [&](OpOperand &use) {
+        return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
+      });
+
+  rewriter.replaceUsesWithIf(
+      opB->getResult(0), newAccB.getResult(), [&](OpOperand &use) {
+        return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
+      });
+
+  return success();
 }
 
 // This function shuffles the vectors written by vector.contract operation
@@ -303,7 +306,7 @@ LogicalResult shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
   mlir::Value vecB = getWrittenVector(opB);
 
   if (!vecA || !vecB)
-    return mlir::failure();
+    return failure();
 
   // Decide insertion point and location
   mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
@@ -344,7 +347,7 @@ LogicalResult shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
 //  (2) - the defining source memref should be same for nonUnitDim
 //  operation,
 //  (3) - the nonUnit dim offset difference between the
-//  vector.contracts should be 8.
+//  vector.contracts should be 8 or 16.
 bool validatePairVectorContract(vector::ContractionOp contractOp,
                                 vector::ContractionOp pairContOp,
                                 bool rhsHasMultipleNonUnitDims,
@@ -398,7 +401,7 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
     return shuffleLw.getV1() == shuffleHw.getV1() &&
            shuffleLw.getV2() == shuffleHw.getV2();
 
-  if (!(srcBuff == srcBuffPairContOp))
+  if (srcBuff != srcBuffPairContOp)
     return false;
 
   for (size_t i = 0; i < indexVals.size(); i++) {

>From 31e1944d73be0b1719509c00d8651506856df0bb Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 11 Feb 2026 07:42:41 -0800
Subject: [PATCH 16/17] code clean-up

---
 .../X86Vector/Utils/X86VectorUtils.cpp        | 86 +++++++++----------
 1 file changed, 40 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index 6c76d80920ce9..b1ce3a8ddfd33 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -236,16 +236,15 @@ Operation *traceToVectorWriteLikeUserOperation(Value v) {
 // contraction ops, enabling post-read layout or packing transformations.
 // TODO: replace all use with the packed value along with contration
 // and for op.
-LogicalResult shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
-                                     mlir::Operation *opA, mlir::Operation *opB,
-                                     mlir::vector::ContractionOp contractA,
-                                     mlir::vector::ContractionOp contractB,
-                                     int64_t nonUnitDimAcc,
-                                     mlir::VectorType accTy) {
-
-  if (!mlir::isa<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(opA) ||
-      !mlir::isa<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(opB)) {
-    return mlir::failure();
+LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
+                                     mlir::Operation *opB,
+                                     vector::ContractionOp contractA,
+                                     vector::ContractionOp contractB,
+                                     int64_t nonUnitDimAcc, VectorType accTy) {
+
+  if (!isa<mlir::vector::TransferReadOp, vector::LoadOp>(opA) ||
+      !isa<mlir::vector::TransferReadOp, vector::LoadOp>(opB)) {
+    return failure();
   }
 
   mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
@@ -254,24 +253,22 @@ LogicalResult shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
   mlir::Location loc = insertAfter->getLoc();
 
   auto elemTy = accTy.getElementType();
-  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+  auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
 
-  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opA->getResult(0));
-  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opB->getResult(0));
+  auto castA =
+      vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->getResult(0));
+  auto castB =
+      vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));
 
   auto masks = getShuffleMasks(nonUnitDimAcc);
 
-  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, masks.maskLo);
-  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, masks.maskHi);
+  auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                             castB, masks.maskLo);
+  auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                             castB, masks.maskHi);
 
-  auto newAccA =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
-  auto newAccB =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
+  auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
+  auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
 
   rewriter.replaceUsesWithIf(
       opA->getResult(0), newAccA.getResult(), [&](OpOperand &use) {
@@ -288,52 +285,49 @@ LogicalResult shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
 
 // This function shuffles the vectors written by vector.contract operation
 // as a flat layout structure before they are stored.
-LogicalResult shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
-                                       mlir::Operation *opA,
-                                       mlir::Operation *opB,
+LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
+                                       Operation *opA, Operation *opB,
                                        int64_t nonUnitDimAcc,
-                                       mlir::VectorType accTy) {
+                                       VectorType accTy) {
   // Helper to extract vector operand from write-like ops
-  auto getWrittenVector = [](mlir::Operation *op) -> mlir::Value {
-    if (auto write = mlir::dyn_cast<mlir::vector::TransferWriteOp>(op))
+  auto getWrittenVector = [](Operation *op) -> mlir::Value {
+    if (auto write = mlir::dyn_cast<vector::TransferWriteOp>(op))
       return write.getVector();
-    if (auto store = mlir::dyn_cast<mlir::vector::StoreOp>(op))
+    if (auto store = mlir::dyn_cast<vector::StoreOp>(op))
       return store.getValueToStore();
     return nullptr;
   };
 
-  mlir::Value vecA = getWrittenVector(opA);
-  mlir::Value vecB = getWrittenVector(opB);
+  Value vecA = getWrittenVector(opA);
+  Value vecB = getWrittenVector(opB);
 
   if (!vecA || !vecB)
     return failure();
 
   // Decide insertion point and location
-  mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
+  Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
 
   rewriter.setInsertionPoint(insertBefore);
-  mlir::Location loc = insertBefore->getLoc();
+  Location loc = insertBefore->getLoc();
 
   auto elemTy = accTy.getElementType();
   auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
 
   // Flatten vectors
-  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
-  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
+  auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
+  auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
 
   // TODO: derive shuffle masks instead of hard-coding
   auto masks = getShuffleMasks(nonUnitDimAcc);
 
-  auto shuffledLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
-                                                    castA, castB, masks.maskLo);
-  auto shuffledHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
-                                                    castA, castB, masks.maskHi);
+  auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                              castB, masks.maskLo);
+  auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                              castB, masks.maskHi);
 
   // Cast back to accumulator type
-  auto newVecA =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
-  auto newVecB =
-      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
+  auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
+  auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
 
   // Update write operands in place
   opA->setOperand(0, newVecA.getResult());
@@ -367,7 +361,7 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
 
   Value srcBuff;
   SmallVector<OpFoldResult> indexVals;
-  llvm::TypeSwitch<mlir::Operation *>(nonUnitOperand.getDefiningOp())
+  llvm::TypeSwitch<Operation *>(nonUnitOperand.getDefiningOp())
       .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
         srcBuff = readOp.getOperand(0);
         indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
@@ -380,7 +374,7 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
 
   Value srcBuffPairContOp;
   SmallVector<OpFoldResult> indexValsPairContOp;
-  llvm::TypeSwitch<mlir::Operation *>(nonUnitOperandPairContOp.getDefiningOp())
+  llvm::TypeSwitch<Operation *>(nonUnitOperandPairContOp.getDefiningOp())
       .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
         srcBuffPairContOp = readOp.getOperand(0);
         indexValsPairContOp = SmallVector<OpFoldResult>(

>From 0ed194809c3369051ec9a09a8e384612b1a27d0f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 12 Feb 2026 01:34:31 -0800
Subject: [PATCH 17/17] cleanup + typos + new loop unit tests

---
 .../Transforms/VectorContractBF16ToFMA.cpp    | 11 ++-
 .../VectorContractToPackedTypeDotProduct.cpp  |  2 +-
 .../X86Vector/Utils/X86VectorUtils.cpp        |  6 +-
 .../vector-contract-bf16-to-fma.mlir          | 78 +++++++++++++++++++
 ...or-contract-to-packed-type-dotproduct.mlir | 78 +++++++++++++++++++
 5 files changed, 164 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 7e9cd32db6fc7..f848f8085ca1d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -64,8 +64,7 @@ static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
   // The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
   // an eight-element tuple of bf16 values to be contiguous.
   int dimsToCheck = isVnni ? 2 : 1;
-  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(
-          dimsToCheck))
+  if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
     return false;
 
   // Return false if the vnni offset of load or transfer_read is not zero.
@@ -127,7 +126,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   }
 
   auto one = rewriter.getIndexAttr(1);
-  llvm::SmallVector<memref::SubViewOp> subviews;
+  SmallVector<memref::SubViewOp> subviews;
 
   if (!isVNNI) {
     SmallVector<OpFoldResult> strides(indexVals.size(), one);
@@ -380,7 +379,7 @@ struct VectorContractBF16ToFMA
       Operation *accReadOp1 =
           traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
 
-      // Iterate dowm to find the users of contact operations until it is store
+      // Iterate down to find the users of contract operations until it is store
       // or transfer_write.
       Operation *resultWriteOp0 =
           traceToVectorWriteLikeUserOperation(contractOp.getResult());
@@ -390,7 +389,7 @@ struct VectorContractBF16ToFMA
       if (!accReadOp0 || !accReadOp1)
         return rewriter.notifyMatchFailure(
             contractOp,
-            "Operands doesn't have load or transfer_read as it's parent op");
+            "Operand doesn't have load or transfer_read as its parent op");
 
       if (!resultWriteOp0 || !resultWriteOp1)
         return rewriter.notifyMatchFailure(
@@ -458,7 +457,7 @@ struct VectorContractBF16ToFMA
                                                     accTyPairCont, oddIdxFMA);
       rewriter.replaceOp(pairContractOp, castOddFma);
 
-      // Shuffle the output of contract operations before it's use.
+      // Shuffle the output of contract operations before its use.
       LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
           rewriter, resultWriteOp0, resultWriteOp1, nonUnitDim, accTy);
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 2cc8736e9cc9f..0da02c08975ad 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -298,7 +298,7 @@ struct VectorContractToPackedTypeDotProduct
         Operation *accReadOp1 =
             traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
 
-        // Iterate dowm to find the users of contact operations until it is
+        // Iterate down to find the users of contract operations until it is
         // store or transfer_write.
         Operation *resultWriteOp0 =
             traceToVectorWriteLikeUserOperation(contractOp.getResult());
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index b1ce3a8ddfd33..1b10e94db4442 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -147,13 +147,11 @@ Operation *traceToVectorReadLikeParentOperation(Value v) {
   while (true) {
     // Case 1: Value defined by an operation
     if (Operation *defOp = v.getDefiningOp()) {
-      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp)) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
         return defOp;
-      }
 
-      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp)) {
+      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp))
         return nullptr;
-      }
 
       if (defOp->getNumOperands() == 1) {
         v = defOp->getOperand(0);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index a29ffda283b3c..714dcf3d44b0c 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -377,6 +377,84 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x1x1xbf16, strided<[2048, 32, 1], offset: ?>>
+!memrefB = memref<1x1x16xbf16, strided<[2048, 64, 1], offset: ?>>
+!memrefC = memref<1x16xf32, strided<[64, 1], offset: ?>>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @matmul_to_fma_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: memref<16x32x64xbf16>, 
+              %arg2: memref<64x64xf32>) -> memref<64x64xf32>  {
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c64 step %c1 {
+    scf.for %arg4 = %c0 to %c64 step %c16 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [1, 16] [1, 1] : memref<64x64xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c8], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %5:2 = scf.for %arg8 = %c0 to %c32 step %c1 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg8] [1, 1, 1] [1, 1, 1] : memref<16x64x32xbf16> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg4] [1, 1, 16] [1, 1, 1] : memref<16x32x64xbf16> to !memrefB
+
+          %6 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : !memrefA, !vecA
+          %7 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : !memrefB, !vecB
+          %8 = vector.transfer_read %subview_1[%c0, %c0, %c8], %1 {in_bounds = [true, true, true]} : !memrefB, !vecB
+
+          %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %arg9 {unroll_shape = array<i64: 1, 1, 8, 1>} : !vecA, !vecB into !vecC
+          %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %8, %arg10 {unroll_shape = array<i64: 1, 1, 8, 1>} : !vecA, !vecB into !vecC
+
+          scf.yield %9, %10 : !vecC, !vecC
+        }
+        scf.yield %5#0, %5#1 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+    }
+  }
+
+  return %arg2 : memref<64x64xf32>
+}
+
+// CHECK-LABEL: @matmul_to_fma_flat_layout_loop
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma {{.*}} : vector<8xf32>
+// CHECK: scf.yield
+// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<8x1xbf16>
 !vecB = vector<1x1xbf16>
 !vecC = vector<8x1xf32>
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 8c78a293f57a0..76b517bf9c872 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -412,6 +412,84 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<1x1x2xbf16, strided<[2048, 32, 1], offset: ?>>
+!memrefB = memref<1x2x32xbf16, strided<[2048, 64, 1], offset: ?>>
+!memrefC = memref<1x32xf32, strided<[64, 1], offset: ?>>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @brmatmul_bf16dp_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: memref<16x32x64xbf16>,
+                             %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  scf.for %arg3 = %c0 to %c64 step %c1 {
+    scf.for %arg4 = %c0 to %c64 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [1, 32] [1, 1] : memref<64x64xf32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]} : !memrefC, !vecC
+
+      %4:2 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3) -> (!vecC, !vecC) {
+        %5:2 = scf.for %arg8 = %c0 to %c32 step %c2 iter_args(%arg9 = %arg6, %arg10 = %arg7) -> (!vecC, !vecC) {
+
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg8] [1, 1, 2] [1, 1, 1] : memref<16x64x32xbf16> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg8, %arg4] [1, 2, 32] [1, 1, 1] : memref<16x32x64xbf16> to !memrefB
+
+          %6 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : !memrefA, !vecA
+          %7 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]} : !memrefB, !vecB
+          %8 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]} : !memrefB, !vecB
+
+          %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %arg9 {unroll_shape = array<i64: 1, 1, 16, 2>} : !vecA, !vecB into !vecC
+          %10 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %8, %arg10 {unroll_shape = array<i64: 1, 1, 16, 2>} : !vecA, !vecB into !vecC
+
+          scf.yield %9, %10 : !vecC, !vecC
+        }
+        scf.yield %5#0, %5#1 : !vecC, !vecC
+      }
+
+      vector.transfer_write %4#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %4#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+    }
+  }
+
+  return %arg2 : memref<64x64xf32>
+}
+
+// CHECK-LABEL: @brmatmul_bf16dp_flat_layout_loop
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86vector.avx512.dot
+// CHECK: x86vector.avx512.dot
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x2xbf16>
 !vecB = vector<2x16xbf16>
 !vecC = vector<1x16xf32>



More information about the Mlir-commits mailing list