[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to FMA or packed type dot-product (PR #168074)

Arun Thangamani llvmlistbot at llvm.org
Fri Nov 21 03:23:18 PST 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/168074

>From 5b9f7e01084e8f21f873e4b893a0dbb5f7f464d1 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 14 Nov 2025 07:34:52 -0800
Subject: [PATCH 1/6] transform pass to lower vector.contract to fma or packed
 dp

---
 .../mlir/Dialect/X86Vector/CMakeLists.txt     |   2 +
 .../X86Vector/TransformOps/CMakeLists.txt     |   4 +
 .../TransformOps/X86VectorTransformOps.h      |  31 ++
 .../TransformOps/X86VectorTransformOps.td     |  42 ++
 .../mlir/Dialect/X86Vector/Transforms.h       |  10 +
 mlir/lib/Dialect/X86Vector/CMakeLists.txt     |   1 +
 .../X86Vector/TransformOps/CMakeLists.txt     |  17 +
 .../TransformOps/X86VectorTransformOps.cpp    |  65 +++
 .../X86Vector/Transforms/CMakeLists.txt       |   2 +
 .../Transforms/VectorContractToFMA.cpp        |  99 +++++
 .../VectorContractToPackedTypeDotProduct.cpp  | 148 +++++++
 mlir/lib/RegisterAllExtensions.cpp            |   2 +
 .../X86Vector/vector-contract-to-fma.mlir     | 210 ++++++++++
 ...or-contract-to-packed-type-dotproduct.mlir | 374 ++++++++++++++++++
 14 files changed, 1007 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
 create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
 create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
 create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
 create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
 
 add_mlir_interface(X86VectorInterfaces)
 add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+/// Registers the X86Vector transform dialect extension with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..4009a140bb097
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,47 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower an F32 `vector.contract` with the `add`
+    combining kind into a `vector.broadcast` of the LHS followed by a
+    `vector.fma`. All LHS dimensions must be 1, and the RHS and accumulator
+    must each have exactly one non-unit dimension.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower a VNNI-packed `vector.contract` with the
+    `add` combining kind into a broadcast of the LHS followed by an x86vector
+    packed dot-product: BF16 inputs (VNNI factor 2) accumulating into F32
+    lower to `x86vector.avx512.dot`, and Int8 inputs (VNNI factor 4)
+    accumulating into Int32 lower to `x86vector.avx.dot.i8`.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..0fda47c180971 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,12 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+/// Adds patterns lowering all-unit-LHS F32 vector.contract ops to vector.fma.
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+/// Adds patterns lowering BF16/Int8 vector.contract ops to x86vector dot ops.
+void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..39c5cf5d9bb54
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,70 @@
+//===- X86VectorTransformOps.cpp - X86Vector transform ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+/// Adds the patterns lowering vector.contract to broadcast + vector.fma.
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+/// Adds the patterns lowering vector.contract to packed dot-product ops.
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the X86Vector transform ops.
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    // The registered patterns create x86vector and vector dialect ops.
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+/// Registers the X86Vector transform dialect extension with `registry`.
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..3d2288049e49e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  VectorContractToFMA.cpp
+  VectorContractToPackedTypeDotProduct.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000000000..764ec46681094
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,122 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Returns the sizes of the non-unit dimensions of `shape`, in order.
+static SmallVector<int64_t> collectNonUnitDims(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> dims;
+  llvm::copy_if(shape, std::back_inserter(dims),
+                [](int64_t dim) { return dim != 1; });
+  return dims;
+}
+
+namespace {
+
+/// Implements an outer-product contraction as a broadcast followed by an FMA.
+///
+/// Matches F32 `vector.contract` ops with the `add` combining kind whose LHS
+/// dimensions are all 1 and whose RHS and accumulator each have exactly one
+/// non-unit dimension of the same size. For example:
+/// ```
+///   vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+/// ```
+/// is rewritten to
+/// ```
+///   vector.broadcast %lhs : vector<1xf32> to vector<16xf32>
+///   vector.fma vector<16xf32>
+/// ```
+/// with `vector.shape_cast` ops flattening the operands and restoring the
+/// accumulator shape.
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    VectorType rhsTy = contractOp.getRhsType();
+    // Return a match failure, rather than asserting, on a scalar accumulator.
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects a vector accumulator");
+
+    // vector.fma needs one element type for all operands, so mixed-precision
+    // contractions (e.g. f32 x f32 -> f64) must be rejected.
+    if (!lhsTy.getElementType().isF32() || !rhsTy.getElementType().isF32() ||
+        !accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 lowering is supported.");
+
+    // Flattening into fixed-size 1-D vectors would drop scalable dimensions.
+    if (lhsTy.isScalable() || rhsTy.isScalable() || accTy.isScalable())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Scalable vectors are not supported");
+
+    if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects one for all dimensions of LHS");
+
+    SmallVector<int64_t> dimsRhs = collectNonUnitDims(rhsTy.getShape());
+    if (dimsRhs.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+    SmallVector<int64_t> dimsAcc = collectNonUnitDims(accTy.getShape());
+    if (dimsAcc.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+    // The flattened RHS and ACC feed vector.fma and must have equal length.
+    if (dimsRhs.front() != dimsAcc.front())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "RHS and ACC sizes do not match");
+
+    // Lowers vector.contract into a broadcast+FMA sequence.
+    Location loc = contractOp.getLoc();
+    Type elemTy = accTy.getElementType();
+    VectorType flatTy = VectorType::get(dimsRhs.front(), elemTy);
+    auto castLhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(1, elemTy), contractOp.getLhs());
+    auto castRhs =
+        vector::ShapeCastOp::create(rewriter, loc, flatTy, contractOp.getRhs());
+    auto castAcc =
+        vector::ShapeCastOp::create(rewriter, loc, flatTy, contractOp.getAcc());
+    auto broadcastLhs =
+        vector::BroadcastOp::create(rewriter, loc, flatTy, castLhs);
+    auto fma =
+        vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+
+    rewriter.replaceOp(contractOp, castFma);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToFMAPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000000000..1dabbddbebb7e
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,177 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Returns the sizes of the non-unit dimensions of `shape`, in order.
+static SmallVector<int64_t> collectNonUnitDims(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> dims;
+  llvm::copy_if(shape, std::back_inserter(dims),
+                [](int64_t dim) { return dim != 1; });
+  return dims;
+}
+
+namespace {
+
+/// Implements a packed-type (VNNI layout) outer-product contraction as a
+/// broadcast followed by an x86vector packed dot-product.
+///
+/// Supported element types:
+///   - BF16 x BF16 -> F32 (VNNI factor 2), lowered to `x86vector.avx512.dot`
+///     with 4, 8 or 16 output lanes.
+///   - I8 x I8 -> I32 (VNNI factor 4), lowered to `x86vector.avx.dot.i8`
+///     with 4 or 8 output lanes.
+///
+/// For example, for BF16:
+/// ```
+///   vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+/// ```
+/// is rewritten to
+/// ```
+///   vector.bitcast %lhs : vector<2xbf16> to vector<1xi32>
+///   vector.broadcast : vector<1xi32> to vector<16xi32>
+///   vector.bitcast : vector<16xi32> to vector<32xbf16>
+///   x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+/// ```
+/// with `vector.shape_cast` ops flattening the operands and restoring the
+/// accumulator shape.
+struct VectorContractToPackedTypeDotProduct
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    VectorType rhsTy = contractOp.getRhsType();
+    // Return a match failure, rather than asserting, on a scalar accumulator.
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects a vector accumulator");
+
+    // Flattening into fixed-size 1-D vectors would drop scalable dimensions.
+    if (lhsTy.isScalable() || rhsTy.isScalable() || accTy.isScalable())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Scalable vectors are not supported");
+
+    Type elemTy = lhsTy.getElementType();
+    const bool isBF16 = elemTy.isBF16();
+    const bool isInt8 = elemTy.isSignlessInteger(8);
+    if (!isBF16 && !isInt8)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only BF16/Int8 lowering is supported.");
+    // Both dot-product ops take identically typed packed operands.
+    if (rhsTy.getElementType() != elemTy)
+      return rewriter.notifyMatchFailure(
+          contractOp, "LHS and RHS element types must match.");
+    // Each dot-product op fixes its accumulator type: F32 for BF16 and I32
+    // for Int8.
+    Type accElemTy = accTy.getElementType();
+    if ((isBF16 && !accElemTy.isF32()) ||
+        (isInt8 && !accElemTy.isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects F32 accumulation for BF16 and Int32 for Int8.");
+
+    // Number of packed elements that together form one 32-bit lane.
+    const int64_t vnniFactor = isBF16 ? 2 : 4;
+
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+    if (lhsShape.empty() || lhsShape.back() != vnniFactor)
+      return rewriter.notifyMatchFailure(
+          contractOp, isBF16 ? "The LHS vnni dim should be 2 for BF16."
+                             : "The LHS vnni dim should be 4 for Int8.");
+    if (collectNonUnitDims(lhsShape).size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
+
+    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+    if (rhsShape.empty() || rhsShape.back() != vnniFactor)
+      return rewriter.notifyMatchFailure(
+          contractOp, isBF16 ? "The RHS vnni dim should be 2 for BF16."
+                             : "The RHS vnni dim should be 4 for Int8.");
+    SmallVector<int64_t> dimsRhs = collectNonUnitDims(rhsShape);
+    if (dimsRhs.size() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+    SmallVector<int64_t> dimsAcc = collectNonUnitDims(accTy.getShape());
+    if (dimsAcc.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+    // Number of 32-bit output lanes: it must match the accumulator and be a
+    // width accepted by the selected dot-product op.
+    const int64_t numLanes = dimsRhs.front();
+    if (dimsAcc.front() != numLanes)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "RHS and ACC sizes do not match");
+    const bool supportedWidth =
+        isBF16 ? (numLanes == 4 || numLanes == 8 || numLanes == 16)
+               : (numLanes == 4 || numLanes == 8);
+    if (!supportedWidth)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Unsupported vector width for the packed dot-product");
+
+    // All checks passed: from here on the rewrite always completes, so no IR
+    // is created before a potential match failure.
+    Location loc = contractOp.getLoc();
+    Type i32Ty = rewriter.getIntegerType(32);
+    VectorType flatAccTy = VectorType::get(numLanes, accElemTy);
+    auto castRhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(numLanes * vnniFactor, elemTy),
+        contractOp.getRhs());
+    auto castAcc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy,
+                                               contractOp.getAcc());
+
+    // Reinterpret the packed LHS elements as a single i32, splat it across all
+    // lanes, and view the result with the packed RHS type.
+    auto castLhs = vector::ShapeCastOp::create(
+        rewriter, loc, VectorType::get(vnniFactor, elemTy),
+        contractOp.getLhs());
+    auto bitcastLhs = vector::BitCastOp::create(
+        rewriter, loc, VectorType::get({1}, i32Ty), castLhs);
+    auto broadcastLhs = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({numLanes}, i32Ty), bitcastLhs);
+    auto bitcastLhsPkType = vector::BitCastOp::create(
+        rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+    Value dp;
+    if (isBF16)
+      dp = x86vector::DotBF16Op::create(rewriter, loc, flatAccTy, castAcc,
+                                        bitcastLhsPkType, castRhs);
+    else
+      dp = x86vector::DotInt8Op::create(rewriter, loc, flatAccTy, castAcc,
+                                        bitcastLhsPkType, castRhs);
+
+    auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+    rewriter.replaceOp(contractOp, castDp);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index c857c38df717c..4312100a0c0b0 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  x86vector::registerTransformDialectExtension(registry);
   xegpu::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
new file mode 100644
index 0000000000000..3a79037ca37c2
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+// Matmul outer product (unit M and K dims) lowers to broadcast + vector.fma.
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma
+// CHECK-COUNT-1: vector.shape_cast{{.*}}to vector<1xf32>
+// CHECK-COUNT-2: vector.shape_cast{{.*}}to vector<64xf32>
+// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32>
+// CHECK: vector.fma{{.*}}vector<64xf32>
+// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Batch matmul with a unit batch dim lowers to broadcast + vector.fma.
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Batch-reduce matmul with a unit batch-reduce dim lowers to vector.fma.
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<3x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_non_unit_batch_dim(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// Batch dimension should've been simplified earlier. A non-unit batch dim
+// makes the LHS non-unit, so the contraction is left as is.
+// CHECK-LABEL: @negative_non_unit_batch_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_non_unit_batch_reduce_dim(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// Batch-reduce dimension should've been simplified earlier. A non-unit
+// batch-reduce dim makes the LHS non-unit, so the contraction is left as is.
+// CHECK-LABEL: @negative_non_unit_batch_reduce_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// Only the `add` combining kind is lowered; `mul` is left untouched.
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_invalid_kind(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_invalid_kind
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
new file mode 100644
index 0000000000000..551f1f95ed9c0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -0,0 +1,374 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_int8dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xf32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#lhs_map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#rhs_map = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#acc_map = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x8x4xi8>
+!vecC = vector<1x8xi32>
+#lhs_map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#rhs_map = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#acc_map = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_int8dp(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#lhs_map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#rhs_map = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#acc_map = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x16xf32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_bf16(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @negative_false_vnni_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xi8>
+!vecB = vector<1x1x8x2xi8>
+!vecC = vector<1x8xi32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_int8(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @negative_false_vnni_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<3x1x1x2xbf16>
+!vecB = vector<3x1x16x2xbf16>
+!vecC = vector<3x1x16xf32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_dimension(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @negative_batch_dimension
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<2x1x1x4xi8>
+!vecB = vector<2x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#lhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#rhs_map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#acc_map = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_dimension(
+  %lhs: !vecA, %rhs: !vecB, %acc: !vecC) -> !vecC
+{
+  %res = vector.contract {
+    indexing_maps = [#lhs_map, #rhs_map, #acc_map],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs, %rhs, %acc
+    : !vecA, !vecB into !vecC
+  return %res : !vecC
+}
+
+// CHECK-LABEL: @negative_brgemm_dimension
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From a7366b22238e084667733f38091eb4d77f6421de Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 14 Nov 2025 07:49:59 -0800
Subject: [PATCH 2/6] fix clang issues

---
 mlir/include/mlir/Dialect/X86Vector/Transforms.h              | 3 ++-
 .../Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp  | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 0fda47c180971..943d7182d1960 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -87,7 +87,8 @@ struct MaskHelper {
 
 void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 
-void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns);
+void populateVectorContractToPackedTypeDotProductPatterns(
+    RewritePatternSet &patterns);
 
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 39c5cf5d9bb54..68d577326a308 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -23,8 +23,8 @@ using namespace mlir;
 using namespace mlir::x86vector;
 using namespace mlir::transform;
 
-void mlir::transform::ApplyVectorContractToFMAPatternsOp::
-    populatePatterns(RewritePatternSet &patterns) {
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
   x86vector::populateVectorContractToFMAPatterns(patterns);
 }
 

>From 416efac3ced77e2008145ee41c513331f62a821e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 20 Nov 2025 05:15:39 -0800
Subject: [PATCH 3/6] VNNI check, new test cases and B broadcast

---
 .../TransformOps/X86VectorTransformOps.td     |   5 +-
 .../TransformOps/X86VectorTransformOps.cpp    |   3 +-
 .../Transforms/VectorContractToFMA.cpp        |  65 +++--
 .../VectorContractToPackedTypeDotProduct.cpp  | 233 ++++++++++++++----
 .../X86Vector/vector-contract-to-fma.mlir     | 105 +++++++-
 ...or-contract-to-packed-type-dotproduct.mlir | 205 +++++++++++++++
 6 files changed, 539 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 4009a140bb097..6192d5e31ffc5 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -21,7 +21,7 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
     "apply_patterns.x86vector.vector_contract_to_fma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector contract operation can be lowered to a FMA.
+    Collect patterns to lower a F32 type vector.contract operation to a FMA.
   }];
   
   let assemblyFormat = "attr-dict";
@@ -31,7 +31,8 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector contract operation can be lowered to a BF16/Int8 dot-product.
+    Collect patterns to lower a BF16/Int8 type vector.contract operation 
+	to a BF16/Int8 dot-product.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 68d577326a308..980c585848080 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,5 +1,4 @@
-//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
-//--===//
+//===- X86VectorTransformOps.cpp -============================================//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 764ec46681094..9349466ba1a34 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -40,29 +40,38 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    if (contractOp.getKind() != vector::CombiningKind::ADD) {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
       return rewriter.notifyMatchFailure(contractOp,
                                          "Expects add combining kind");
-    }
 
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isF32())
       return rewriter.notifyMatchFailure(contractOp,
                                          "Only F32 lowering is supported.");
-    if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
-      return rewriter.notifyMatchFailure(
-          contractOp, "Expects one for all dimensions of LHS");
+
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+    llvm::SmallVector<int64_t> dimsLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+                  [](int64_t dim) { return dim != 1; });
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
     llvm::SmallVector<int64_t> dimsRhs;
     llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
                   [](int64_t dim) { return dim != 1; });
-    if (dimsRhs.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+    if (dimsLhs.size() > 0 && dimsRhs.size() > 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
+
+    if (dimsLhs.size() != 1 && dimsRhs.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Irregular LHS or RHS shape.");
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    assert(accTy && "Invalid accumulator");
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+
     ArrayRef<int64_t> accShape = accTy.getShape();
     llvm::SmallVector<int64_t> dimsAcc;
     llvm::copy_if(accShape, std::back_inserter(dimsAcc),
@@ -72,21 +81,39 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
 
     // Lowers vector.contract into a broadcast+FMA sequence.
     auto loc = contractOp.getLoc();
-    auto castLhs = vector::ShapeCastOp::create(
-        rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
-        contractOp.getLhs());
-    auto castRhs = vector::ShapeCastOp::create(
-        rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
-        contractOp.getRhs());
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
-    auto broadcastLhs = vector::BroadcastOp::create(
-        rewriter, loc, castRhs.getResult().getType(), castLhs);
-    auto fma =
-        vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
-    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
 
+    vector::FMAOp fma;
+
+    if (dimsRhs.size() > 0) {
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto castRhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto broadcastLhs = vector::BroadcastOp::create(
+          rewriter, loc, castRhs.getResult().getType(), castLhs);
+      fma =
+          vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+    } else {
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto castRhs = vector::ShapeCastOp::create(
+          rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto broadcastRhs = vector::BroadcastOp::create(
+          rewriter, loc, castLhs.getResult().getType(), castRhs);
+      fma =
+          vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
+    }
+
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
     rewriter.replaceOp(contractOp, castFma);
 
     return success();
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 1dabbddbebb7e..59082b9761135 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -22,6 +24,90 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  using mlir::utils::IteratorType;
+  if (!map.isProjectedPermutation())
+    return failure();
+  // Dims referenced by the (projected-permutation) output map are parallel.
+  SmallVector<IteratorType> kinds(map.getNumDims(), IteratorType::reduction);
+  for (AffineExpr result : map.getResults())
+    kinds[cast<AffineDimExpr>(result).getPosition()] = IteratorType::parallel;
+  return kinds;
+}
+
+// Returns true if operands 0/1 of `op` and `indexingMaps` describe a
+// contraction in VNNI layout (A: [...][K/v][v], B: [...][K/v][N][v]) with a
+// static, even VNNI factor v that equals `blockingFactor` when one is given.
+static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+                           std::optional<unsigned> blockingFactor) {
+  // Narrow down type operations - VNNI only applies to contractions.
+  FailureOr<linalg::ContractionDimensions> dims =
+      linalg::inferContractionDims(indexingMaps);
+  if (failed(dims))
+    return false;
+
+  auto typeA = dyn_cast<ShapedType>(op->getOperand(0).getType());
+  auto typeB = dyn_cast<ShapedType>(op->getOperand(1).getType());
+  if (!typeA || !typeB || !typeA.hasRank() || !typeB.hasRank())
+    return false;
+  unsigned rankA = typeA.getRank();
+  unsigned rankB = typeB.getRank();
+  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+  if (rankA < 3 || rankB < 3)
+    return false;
+
+  // At least two reduction dimensions are expected:
+  // one for the VNNI factor and one for the K dimension
+  if (dims->k.size() < 2)
+    return false;
+
+  // Validate affine maps - VNNI computation should be defined by the two
+  // innermost reduction iterators.
+  // The input matrix dimensions layout must match the following:
+  //   - matrix A - [...][K/vnniFactor][vnniFactor]
+  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
+  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+  if (failed(maybeIters))
+    return false;
+  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+  AffineMap mapA = indexingMaps[0];
+  AffineMap mapB = indexingMaps[1];
+
+  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+      iteratorTypes[vnniDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+  if (!redDimA || !redDimB || redDimA != redDimB ||
+      iteratorTypes[redDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+                           mlir::utils::IteratorType::parallel)
+    return false;
+
+  // VNNI factor must be:
+  //   - the innermost inputs' dimension
+  //   - statically known
+  //   - a non-zero multiple of 2, and equal to `blockingFactor` if given
+  int64_t vnniDimSize = typeB.getShape().back();
+  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+      vnniDimSize % 2 != 0 || typeA.getShape().back() != vnniDimSize ||
+      (blockingFactor && vnniDimSize != *blockingFactor))
+    return false;
+
+  // The split reduction dimension size should also match.
+  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+    return false;
+
+  return true;
+}
+
 // Implements packed type outer product contraction as a sequence
 // of broadcast and packed dot-product operations.
 //
@@ -41,50 +127,58 @@ struct VectorContractToPackedTypeDotProduct
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    if (contractOp.getKind() != vector::CombiningKind::ADD) {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
       return rewriter.notifyMatchFailure(contractOp,
                                          "Expects add combining kind");
-    }
 
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isBF16() &&
         !lhsTy.getElementType().isSignlessInteger(8))
       return rewriter.notifyMatchFailure(
           contractOp, "Only BF16/Int8 lowering is supported.");
-    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
-    if (lhsTy.getElementType().isBF16() && lhsShape.back() != 2)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The LHS vnni dim should be 2 for BF16.");
 
-    if (lhsTy.getElementType().isSignlessInteger(8) && lhsShape.back() != 4)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The LHS vnni dim should be 4 for Int8.");
+    if (lhsTy.getElementType().isBF16() &&
+        !isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), 2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format");
+
+    if (lhsTy.getElementType().isSignlessInteger(8) &&
+        !isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), 4))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format");
+
+    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
     llvm::SmallVector<int64_t> dimsLhs;
     llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
                   [](int64_t dim) { return dim != 1; });
-    if (dimsLhs.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
-    if (lhsTy.getElementType().isBF16() && rhsShape.back() != 2)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The RHS vnni dim should be 2 for BF16.");
-    if (lhsTy.getElementType().isSignlessInteger(8) && rhsShape.back() != 4)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The RHS vnni dim should be 4 for Int8.");
     llvm::SmallVector<int64_t> dimsRhs;
     llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
                   [](int64_t dim) { return dim != 1; });
-    if (dimsRhs.size() != 2)
-      return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
 
-    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    assert(accTy && "Invalid accumulator");
-    if (!accTy.getElementType().isF32() &&
-        !accTy.getElementType().isSignlessInteger(32))
+    if ((dimsLhs.size() - 1) > 0 && (dimsRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(
-          contractOp, "Only F32/Int32 accumulation is supported.");
+          contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
+
+    if ((dimsLhs.size() - 1) != 1 && (dimsRhs.size() - 1) != 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Irregular LHS or RHS shape.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+
+    if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+        (lhsTy.getElementType().isSignlessInteger(8) &&
+         !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 for BF16 or Int32 for Int8 "
+                                         "accumulation type is supported.");
+
     ArrayRef<int64_t> accShape = accTy.getShape();
     llvm::SmallVector<int64_t> dimsAcc;
     llvm::copy_if(accShape, std::back_inserter(dimsAcc),
@@ -93,43 +187,78 @@ struct VectorContractToPackedTypeDotProduct
       return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
 
     auto loc = contractOp.getLoc();
-    auto castRhs = vector::ShapeCastOp::create(
-        rewriter, loc,
-        VectorType::get(dimsRhs.front() * dimsRhs.back(),
-                        rhsTy.getElementType()),
-        contractOp.getRhs());
-
     auto castAcc = vector::ShapeCastOp::create(
         rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
 
-    auto castLhs = vector::ShapeCastOp::create(
-        rewriter, loc, VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
-        contractOp.getLhs());
-    auto bitcastLhs = vector::BitCastOp::create(
-        rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
-        castLhs);
-    auto broadcastLhs = vector::BroadcastOp::create(
-        rewriter, loc,
-        VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
-        bitcastLhs);
-    auto bitcastLhsPkType = vector::BitCastOp::create(
-        rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
-
     Value dp;
 
-    if (lhsTy.getElementType().isBF16()) {
-      dp = x86vector::DotBF16Op::create(
+    if ((dimsRhs.size() - 1) > 0) {
+      auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
-          bitcastLhsPkType, castRhs);
-    }
+          VectorType::get(dimsRhs.front() * dimsRhs.back(),
+                          rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto bitcastLhs = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+          castLhs);
+      auto broadcastLhs = vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+          bitcastLhs);
+      auto bitcastLhsPkType = vector::BitCastOp::create(
+          rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
 
-    if (lhsTy.getElementType().isSignlessInteger(8)) {
-      dp = x86vector::DotInt8Op::create(
+      if (lhsTy.getElementType().isBF16()) {
+        dp = x86vector::DotBF16Op::create(
+            rewriter, loc,
+            VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
+            bitcastLhsPkType, castRhs);
+      }
+
+      if (lhsTy.getElementType().isSignlessInteger(8)) {
+        dp = x86vector::DotInt8Op::create(
+            rewriter, loc,
+            VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+            castAcc, bitcastLhsPkType, castRhs);
+      }
+    } else {
+      auto castLhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(dimsLhs.front() * dimsLhs.back(),
+                          lhsTy.getElementType()),
+          contractOp.getLhs());
+      auto castRhs = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+          contractOp.getRhs());
+      auto bitcastRhs = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+          castRhs);
+      auto broadcastRhs = vector::BroadcastOp::create(
           rewriter, loc,
-          VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
-          castAcc, bitcastLhsPkType, castRhs);
+          VectorType::get({dimsLhs.front()}, rewriter.getIntegerType(32)),
+          bitcastRhs);
+      auto bitcastRhsPkType = vector::BitCastOp::create(
+          rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+
+      if (lhsTy.getElementType().isBF16()) {
+        dp = x86vector::DotBF16Op::create(
+            rewriter, loc,
+            VectorType::get(dimsLhs.front(), rewriter.getF32Type()), castAcc,
+            castLhs, bitcastRhsPkType);
+      }
+
+      if (lhsTy.getElementType().isSignlessInteger(8)) {
+        dp = x86vector::DotInt8Op::create(
+            rewriter, loc,
+            VectorType::get(dimsLhs.front(), rewriter.getIntegerType(32)),
+            castAcc, castLhs, bitcastRhsPkType);
+      }
     }
 
     if (dp) {
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
index 3a79037ca37c2..cf53710d839a1 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -19,8 +19,6 @@ func.func @matmul_outer_product_to_fma(
 }
 
 // CHECK-LABEL: @matmul_outer_product_to_fma
-// CHECK-COUNT-1: vector.shape_cast{{.*}}to vector<1xf32>
-// CHECK-COUNT-2: vector.shape_cast{{.*}}to vector<64xf32>
 // CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32>
 // CHECK: vector.fma{{.*}}vector<64xf32>
 // CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32>
@@ -37,6 +35,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<64x1xf32>
+!vecB = vector<1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1xf32>
 !vecB = vector<1x1x64xf32>
 !vecC = vector<1x1x64xf32>
@@ -71,6 +103,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<1x64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1xf32>
 !vecB = vector<1x1x64xf32>
 !vecC = vector<1x64xf32>
@@ -105,6 +171,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<3x1x1xf32>
 !vecB = vector<3x1x64xf32>
 !vecC = vector<3x1x64xf32>
@@ -208,3 +308,4 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 551f1f95ed9c0..268592c09a3a7 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -34,6 +34,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x16x1x2xbf16>
+!vecB = vector<1x1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x4xi8>
 !vecB = vector<1x1x8x4xi8>
 !vecC = vector<1x8xi32>
@@ -137,6 +171,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x8x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x8x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x2xbf16>
 !vecB = vector<1x16x2xbf16>
 !vecC = vector<1x16xf32>
@@ -171,6 +240,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<16x1x2xbf16>
+!vecB = vector<1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x4xi8>
 !vecB = vector<1x8x4xi8>
 !vecC = vector<1x8xi32>
@@ -372,3 +475,105 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_acc_type(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_acc_type
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_acc_type(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_acc_type
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x32xbf16>
+!vecC = vector<1x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_not_vnni(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_brgemm_not_vnni
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 942699fe8c8599798b102a648a174ae522576d9f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 21 Nov 2025 00:29:33 -0800
Subject: [PATCH 4/6] Formatting, a few more test cases, and dp vector shape
 checks.

---
 .../TransformOps/X86VectorTransformOps.td     |   2 +-
 .../TransformOps/X86VectorTransformOps.cpp    |   2 +-
 .../Transforms/VectorContractToFMA.cpp        |  46 +++++---
 .../VectorContractToPackedTypeDotProduct.cpp  |  98 +++++++++-------
 .../X86Vector/vector-contract-to-fma.mlir     |  33 ++++++
 ...or-contract-to-packed-type-dotproduct.mlir | 110 +++++++++++++++++-
 6 files changed, 226 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 6192d5e31ffc5..3c5294ff14fc7 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 980c585848080..95db208207672 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.cpp -============================================//
+//===- X86VectorTransformOps.cpp ------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 9349466ba1a34..9a6b4d07ecf58 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -42,7 +42,7 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
 
     if (contractOp.getKind() != vector::CombiningKind::ADD)
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Expects add combining kind");
+                                         "Expects add combining kind.");
 
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isF32())
@@ -50,50 +50,60 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
                                          "Only F32 lowering is supported.");
 
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
-    llvm::SmallVector<int64_t> dimsLhs;
-    llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+    llvm::SmallVector<int64_t> nonUnitDimLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
                   [](int64_t dim) { return dim != 1; });
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
-    llvm::SmallVector<int64_t> dimsRhs;
-    llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+    llvm::SmallVector<int64_t> nonUnitDimRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
                   [](int64_t dim) { return dim != 1; });
 
-    if (dimsLhs.size() > 0 && dimsRhs.size() > 0)
+    if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
       return rewriter.notifyMatchFailure(
           contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
 
-    if (dimsLhs.size() != 1 && dimsRhs.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Irregular LHS or RHS shape.");
+    if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
-      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accmulator is not a vector type");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accmulator should be F32 type.");
 
     ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> dimsAcc;
-    llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
                   [](int64_t dim) { return dim != 1; });
-    if (dimsAcc.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "A or B dimension should be non-unit.");
 
     // Lowers vector.contract into a broadcast+FMA sequence.
     auto loc = contractOp.getLoc();
     auto castAcc = vector::ShapeCastOp::create(
-        rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+        rewriter, loc,
+        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
 
     vector::FMAOp fma;
 
-    if (dimsRhs.size() > 0) {
+    // LHS shape is unit dimension. Broadcast into vector-size of non-unit
+    // dimension in RHS shape.
+    if (nonUnitDimRhs.size() > 0) {
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
           contractOp.getLhs());
       auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+          VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
           contractOp.getRhs());
       auto broadcastLhs = vector::BroadcastOp::create(
           rewriter, loc, castRhs.getResult().getType(), castLhs);
@@ -102,7 +112,7 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
     } else {
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+          VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
           contractOp.getLhs());
       auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 59082b9761135..68cd8700f5be6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -129,7 +129,7 @@ struct VectorContractToPackedTypeDotProduct
 
     if (contractOp.getKind() != vector::CombiningKind::ADD)
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Expects add combining kind");
+                                         "Expects add combining kind.");
 
     VectorType lhsTy = contractOp.getLhsType();
     if (!lhsTy.getElementType().isBF16() &&
@@ -137,40 +137,36 @@ struct VectorContractToPackedTypeDotProduct
       return rewriter.notifyMatchFailure(
           contractOp, "Only BF16/Int8 lowering is supported.");
 
-    if (lhsTy.getElementType().isBF16() &&
-        !isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(), 2))
+    unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+    if (!isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), blockingFactor))
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices not in VNNI format");
-
-    if (lhsTy.getElementType().isSignlessInteger(8) &&
-        !isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(), 4))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices not in VNNI format");
+                                         "Input matrices not in VNNI format.");
 
     ArrayRef<int64_t> lhsShape = lhsTy.getShape();
-    llvm::SmallVector<int64_t> dimsLhs;
-    llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+    llvm::SmallVector<int64_t> nonUnitDimLhs;
+    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
                   [](int64_t dim) { return dim != 1; });
 
     VectorType rhsTy = contractOp.getRhsType();
     ArrayRef<int64_t> rhsShape = rhsTy.getShape();
-    llvm::SmallVector<int64_t> dimsRhs;
-    llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+    llvm::SmallVector<int64_t> nonUnitDimRhs;
+    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
                   [](int64_t dim) { return dim != 1; });
 
-    if ((dimsLhs.size() - 1) > 0 && (dimsRhs.size() - 1) > 0)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
-
-    if ((dimsLhs.size() - 1) != 1 && (dimsRhs.size() - 1) != 1)
+    if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Irregular LHS or RHS shape.");
+                                         "Excepts unit dimensions for either "
+                                         "LHS or RHS shape other than VNNI.");
+
+    if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
-      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
 
     if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
         (lhsTy.getElementType().isSignlessInteger(8) &&
@@ -180,35 +176,55 @@ struct VectorContractToPackedTypeDotProduct
                                          "accumulation type is supported.");
 
     ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> dimsAcc;
-    llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+    llvm::SmallVector<int64_t> nonUnitDimAcc;
+    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
                   [](int64_t dim) { return dim != 1; });
-    if (dimsAcc.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    // Non-unit dimensions should match the vector length of BF16 or Int8
+    // dot-product.
+    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
+                                                        : nonUnitDimRhs.front();
+    if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
+        nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 dot-product operation expects non-unit (LHR or "
+                      "RHS) dim and acc dim of size 4/8/16.");
+
+    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
+        nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Int8 dot-product operation expects non-unit (LHR or "
+                      "RHS) dim and acc dim of size 4/8.");
 
     auto loc = contractOp.getLoc();
     auto castAcc = vector::ShapeCastOp::create(
-        rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+        rewriter, loc,
+        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
         contractOp.getAcc());
 
     Value dp;
 
-    if ((dimsRhs.size() - 1) > 0) {
+    // LHS shape is unit dimension. Broadcast into vector-size of non-unit
+    // dimension in RHS shape. Subtract one to remove VNNI dim.
+    if ((nonUnitDimRhs.size() - 1) > 0) {
       auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsRhs.front() * dimsRhs.back(),
+          VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
                           rhsTy.getElementType()),
           contractOp.getRhs());
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+          VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
           contractOp.getLhs());
       auto bitcastLhs = vector::BitCastOp::create(
           rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
           castLhs);
       auto broadcastLhs = vector::BroadcastOp::create(
           rewriter, loc,
-          VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+          VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
           bitcastLhs);
       auto bitcastLhsPkType = vector::BitCastOp::create(
           rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
@@ -216,32 +232,32 @@ struct VectorContractToPackedTypeDotProduct
       if (lhsTy.getElementType().isBF16()) {
         dp = x86vector::DotBF16Op::create(
             rewriter, loc,
-            VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
-            bitcastLhsPkType, castRhs);
+            VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
+            castAcc, bitcastLhsPkType, castRhs);
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
         dp = x86vector::DotInt8Op::create(
             rewriter, loc,
-            VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
             castAcc, bitcastLhsPkType, castRhs);
       }
-    } else {
+    } else { // RHS shape is unit dimension.
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsLhs.front() * dimsLhs.back(),
+          VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
                           lhsTy.getElementType()),
           contractOp.getLhs());
       auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc,
-          VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+          VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
           contractOp.getRhs());
       auto bitcastRhs = vector::BitCastOp::create(
           rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
           castRhs);
       auto broadcastRhs = vector::BroadcastOp::create(
           rewriter, loc,
-          VectorType::get({dimsLhs.front()}, rewriter.getIntegerType(32)),
+          VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
           bitcastRhs);
       auto bitcastRhsPkType = vector::BitCastOp::create(
           rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
@@ -249,14 +265,14 @@ struct VectorContractToPackedTypeDotProduct
       if (lhsTy.getElementType().isBF16()) {
         dp = x86vector::DotBF16Op::create(
             rewriter, loc,
-            VectorType::get(dimsLhs.front(), rewriter.getF32Type()), castAcc,
-            castLhs, bitcastRhsPkType);
+            VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
+            castAcc, castLhs, bitcastRhsPkType);
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
         dp = x86vector::DotInt8Op::create(
             rewriter, loc,
-            VectorType::get(dimsLhs.front(), rewriter.getIntegerType(32)),
+            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
             castAcc, castLhs, bitcastRhsPkType);
       }
     }
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
index cf53710d839a1..e506b166d43ff 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -309,3 +309,36 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xi32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_accumulator_type(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_accumulator_type
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 268592c09a3a7..65676cbae772c 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -484,7 +484,7 @@ module attributes {transform.with_named_sequence} {
 #map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
 #map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
 #map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_acc_type(
+func.func @negative_float_acc_type(
   %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
 {
   %0 = vector.contract {
@@ -496,7 +496,7 @@ func.func @negative_acc_type(
   return %0 : !vecC
 }
 
-// CHECK-LABEL: @negative_acc_type
+// CHECK-LABEL: @negative_float_acc_type
 // CHECK-NOT: x86vector.avx512.dot
 // CHECK: vector.contract
 
@@ -512,13 +512,47 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi8>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_int_acc_type(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_int_acc_type
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x4xbf16>
 !vecB = vector<1x1x16x4xbf16>
 !vecC = vector<1x1x16xbf16>
 #map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
 #map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
 #map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_acc_type(
+func.func @negative_wrong_vnni_blocking_factor_bf16(
   %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
 {
   %0 = vector.contract {
@@ -530,7 +564,7 @@ func.func @negative_acc_type(
   return %0 : !vecC
 }
 
-// CHECK-LABEL: @negative_acc_type
+// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
 // CHECK-NOT: x86vector.avx512.dot
 // CHECK: vector.contract
 
@@ -577,3 +611,71 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_int8(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_wrong_vector_shape_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x32x2xbf16>
+!vecC = vector<1x1x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_bf16(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_wrong_vector_shape_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 960c24c66e1e96b7b31b0177cd189985ad1ba80d Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 21 Nov 2025 00:56:52 -0800
Subject: [PATCH 5/6] fix clang-format issue

---
 mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 9a6b4d07ecf58..f6169cb6b5400 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -83,8 +83,8 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
     llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
                   [](int64_t dim) { return dim != 1; });
     if (nonUnitDimAcc.size() != 1)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "A or B dimension should be non-unit.");
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B dimension should be non-unit.");
 
     // Lowers vector.contract into a broadcast+FMA sequence.
     auto loc = contractOp.getLoc();

>From 64a6b5ddc390ec438a51793ec7fa029b3f8dca70 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 21 Nov 2025 03:23:01 -0800
Subject: [PATCH 6/6] code refactoring

---
 .../mlir/Dialect/X86Vector/Transforms.h       |  7 +++--
 .../Transforms/VectorContractToFMA.cpp        | 11 ++++++--
 .../VectorContractToPackedTypeDotProduct.cpp  | 26 ++++++++++++-------
 3 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 943d7182d1960..0bfd3a45ff7ba 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,10 +11,6 @@
 
 #include "mlir/IR/Value.h"
 
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -85,8 +81,11 @@ struct MaskHelper {
 
 //===----------------------------------------------------------------------===//
 
+// Lowers an FP32 vector.contract operation to a broadcast + FMA sequence.
 void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 
+// Lowers a BF16/Int8 vector.contract operation to an x86vector BF16/Int8
+// packed dot-product operation (DotBF16Op / DotInt8Op).
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index f6169cb6b5400..f3af5ca167a35 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -22,6 +22,8 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+namespace {
+
 // Implements outer product contraction as a sequence of broadcast and
 // FMA operations.
 //
@@ -95,8 +97,11 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
 
     vector::FMAOp fma;
 
-    // LHS shape is unit dimension. Broadcast into vector-size of non-unit
-    // dimension in RHS shape.
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+    // corresponding non-unit dimension on the other operand. For example,
+    // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we
+    // broadcast the LHS to vector<1x16xf32>. In the opposite case (non-unit
+    // dimension on the LHS), we broadcast the RHS instead.
     if (nonUnitDimRhs.size() > 0) {
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
@@ -130,6 +135,8 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
   }
 };
 
+} // namespace
+
 void x86vector::populateVectorContractToFMAPatterns(
     RewritePatternSet &patterns) {
   patterns.add<VectorContractToFMA>(patterns.getContext());
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 68cd8700f5be6..1e64811db910b 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -24,6 +24,8 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+namespace {
+
 static FailureOr<SmallVector<mlir::utils::IteratorType>>
 inferIteratorsFromOutMap(AffineMap map) {
   if (!map.isProjectedPermutation())
@@ -36,6 +38,8 @@ inferIteratorsFromOutMap(AffineMap map) {
   return iterators;
 }
 
+// Returns true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
 static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
                            std::optional<unsigned> blockingFactor) {
   // Narrow down type operations - VNNI only applies to contractions.
@@ -207,8 +211,11 @@ struct VectorContractToPackedTypeDotProduct
 
     Value dp;
 
-    // LHS shape is unit dimension. Broadcast into vector-size of non-unit
-    // dimension in RHS shape. Subtract one to remove VNNI dim.
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+    // corresponding non-unit dimension of the other operand. For example, if
+    // LHS is vector<1x1x2xbf16> and RHS is vector<1x16x2xbf16>, the LHS VNNI
+    // pair is bitcast to i32, broadcast to vector<16xi32> and bitcast back to
+    // vector<32xbf16>. Otherwise (unit-dimension RHS), the RHS is broadcast.
     if ((nonUnitDimRhs.size() - 1) > 0) {
       auto castRhs = vector::ShapeCastOp::create(
           rewriter, loc,
@@ -242,7 +249,7 @@ struct VectorContractToPackedTypeDotProduct
             VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
             castAcc, bitcastLhsPkType, castRhs);
       }
-    } else { // RHS shape is unit dimension.
+    } else {
       auto castLhs = vector::ShapeCastOp::create(
           rewriter, loc,
           VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
@@ -277,16 +284,17 @@ struct VectorContractToPackedTypeDotProduct
       }
     }
 
-    if (dp) {
-      auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
-      rewriter.replaceOp(contractOp, castDp);
-      return success();
-    }
+    if (!dp)
+      return failure();
 
-    return failure();
+    auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+    rewriter.replaceOp(contractOp, castDp);
+    return success();
   }
 };
 
+} // namespace
+
 void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns) {
   patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());



More information about the Mlir-commits mailing list