[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to FMA or packed type dot-product (PR #168074)
Arun Thangamani
llvmlistbot at llvm.org
Fri Nov 21 00:57:09 PST 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/168074
>From 5b9f7e01084e8f21f873e4b893a0dbb5f7f464d1 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 14 Nov 2025 07:34:52 -0800
Subject: [PATCH 1/5] transform pass to lower vector.contract to fma or packed
dp
---
.../mlir/Dialect/X86Vector/CMakeLists.txt | 2 +
.../X86Vector/TransformOps/CMakeLists.txt | 4 +
.../TransformOps/X86VectorTransformOps.h | 31 ++
.../TransformOps/X86VectorTransformOps.td | 42 ++
.../mlir/Dialect/X86Vector/Transforms.h | 10 +
mlir/lib/Dialect/X86Vector/CMakeLists.txt | 1 +
.../X86Vector/TransformOps/CMakeLists.txt | 17 +
.../TransformOps/X86VectorTransformOps.cpp | 65 +++
.../X86Vector/Transforms/CMakeLists.txt | 2 +
.../Transforms/VectorContractToFMA.cpp | 99 +++++
.../VectorContractToPackedTypeDotProduct.cpp | 148 +++++++
mlir/lib/RegisterAllExtensions.cpp | 2 +
.../X86Vector/vector-contract-to-fma.mlir | 210 ++++++++++
...or-contract-to-packed-type-dotproduct.mlir | 374 ++++++++++++++++++
14 files changed, 1007 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..4009a140bb097
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,42 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_to_fma",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that a vector contract operation can be lowered to an FMA.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that a vector contract operation can be lowered to a BF16/Int8 dot-product.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
+
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..0fda47c180971 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
#include "mlir/IR/Value.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class ImplicitLocOpBuilder;
@@ -79,6 +83,12 @@ struct MaskHelper {
}
};
+//===----------------------------------------------------------------------===//
+
+void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
+
+void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..39c5cf5d9bb54
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,65 @@
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
+//--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..3d2288049e49e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ VectorContractToFMA.cpp
+ VectorContractToPackedTypeDotProduct.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000000000..764ec46681094
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,99 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 lowering is supported.");
+ if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects one for all dimensions of LHS");
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> dimsRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsRhs.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ assert(accTy && "Invalid accumulator");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> dimsAcc;
+ llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsAcc.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+ // Lowers vector.contract into a broadcast+FMA sequence.
+ auto loc = contractOp.getLoc();
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), castLhs);
+ auto fma =
+ vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+ auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+
+ rewriter.replaceOp(contractOp, castFma);
+
+ return success();
+ }
+};
+
+void x86vector::populateVectorContractToFMAPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000000000..1dabbddbebb7e
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,148 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements packed type outer product contraction as a sequence
+// of broadcast and packed dot-product operations.
+//
+// For example - for BF16 type:
+// ```
+// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
+struct VectorContractToPackedTypeDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16() &&
+ !lhsTy.getElementType().isSignlessInteger(8))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only BF16/Int8 lowering is supported.");
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ if (lhsTy.getElementType().isBF16() && lhsShape.back() != 2)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The LHS vnni dim should be 2 for BF16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && lhsShape.back() != 4)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The LHS vnni dim should be 4 for Int8.");
+ llvm::SmallVector<int64_t> dimsLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsLhs.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ if (lhsTy.getElementType().isBF16() && rhsShape.back() != 2)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The RHS vnni dim should be 2 for BF16.");
+ if (lhsTy.getElementType().isSignlessInteger(8) && rhsShape.back() != 4)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The RHS vnni dim should be 4 for Int8.");
+ llvm::SmallVector<int64_t> dimsRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsRhs.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ assert(accTy && "Invalid accumulator");
+ if (!accTy.getElementType().isF32() &&
+ !accTy.getElementType().isSignlessInteger(32))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only F32/Int32 accumulation is supported.");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> dimsAcc;
+ llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (dimsAcc.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+
+ auto loc = contractOp.getLoc();
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front() * dimsRhs.back(),
+ rhsTy.getElementType()),
+ contractOp.getRhs());
+
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto bitcastLhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castLhs);
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+ bitcastLhs);
+ auto bitcastLhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+ Value dp;
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
+ bitcastLhsPkType, castRhs);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+
+ if (dp) {
+ auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+ rewriter.replaceOp(contractOp, castDp);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index c857c38df717c..4312100a0c0b0 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
+ x86vector::registerTransformDialectExtension(registry);
xegpu::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
new file mode 100644
index 0000000000000..3a79037ca37c2
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma
+// CHECK-COUNT-1: vector.shape_cast{{.*}}to vector<1xf32>
+// CHECK-COUNT-2: vector.shape_cast{{.*}}to vector<64xf32>
+// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32>
+// CHECK: vector.fma{{.*}}vector<64xf32>
+// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<3x1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_non_unit_batch_dim(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// Batch dimension should've been simplified earlier.
+
+// CHECK-LABEL: @negative_non_unit_batch_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1xf32>
+!vecB = vector<3x1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_non_unit_batch_reduce_dim(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// Batch-reduce dimension should've been simplified earlier.
+
+// CHECK-LABEL: @negative_non_unit_batch_reduce_dim
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1xf32>
+!vecB = vector<1x64xf32>
+!vecC = vector<1x64xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_invalid_kind(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_invalid_kind
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
new file mode 100644
index 0000000000000..551f1f95ed9c0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -0,0 +1,374 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x16x2xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_invalid_vc_kind(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_invalid_vc_kind
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_bf16(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_false_vnni_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xi8>
+!vecB = vector<1x1x8x2xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_vnni_int8(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_false_vnni_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<3x1x1x2xbf16>
+!vecB = vector<3x1x16x2xbf16>
+!vecC = vector<3x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_dimension(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_batch_dimension
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<2x1x1x4xi8>
+!vecB = vector<2x1x8x4xi8>
+!vecC = vector<1x8xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_dimension(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_brgemm_dimension
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From a7366b22238e084667733f38091eb4d77f6421de Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 14 Nov 2025 07:49:59 -0800
Subject: [PATCH 2/5] fix clang issues
---
mlir/include/mlir/Dialect/X86Vector/Transforms.h | 3 ++-
.../Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp | 4 ++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 0fda47c180971..943d7182d1960 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -87,7 +87,8 @@ struct MaskHelper {
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
-void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns);
+void populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 39c5cf5d9bb54..68d577326a308 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -23,8 +23,8 @@ using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::transform;
-void mlir::transform::ApplyVectorContractToFMAPatternsOp::
- populatePatterns(RewritePatternSet &patterns) {
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
x86vector::populateVectorContractToFMAPatterns(patterns);
}
>From 416efac3ced77e2008145ee41c513331f62a821e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 20 Nov 2025 05:15:39 -0800
Subject: [PATCH 3/5] VNNI check, new test-cases and B broadcast
---
.../TransformOps/X86VectorTransformOps.td | 5 +-
.../TransformOps/X86VectorTransformOps.cpp | 3 +-
.../Transforms/VectorContractToFMA.cpp | 65 +++--
.../VectorContractToPackedTypeDotProduct.cpp | 233 ++++++++++++++----
.../X86Vector/vector-contract-to-fma.mlir | 105 +++++++-
...or-contract-to-packed-type-dotproduct.mlir | 205 +++++++++++++++
6 files changed, 539 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 4009a140bb097..6192d5e31ffc5 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -21,7 +21,7 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector contract operation can be lowered to a FMA.
+ Collect patterns to lower a F32 type vector.contract operation to a FMA.
}];
let assemblyFormat = "attr-dict";
@@ -31,7 +31,8 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector contract operation can be lowered to a BF16/Int8 dot-product.
+ Collect patterns to lower a BF16/Int8 type vector.contract operation
+ to a BF16/Int8 dot-product.
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 68d577326a308..980c585848080 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,5 +1,4 @@
-//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
-//--===//
+//===- X86VectorTransformOps.cpp -============================================//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 764ec46681094..9349466ba1a34 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -40,29 +40,38 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind");
- }
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isF32())
return rewriter.notifyMatchFailure(contractOp,
"Only F32 lowering is supported.");
- if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
- return rewriter.notifyMatchFailure(
- contractOp, "Expects one for all dimensions of LHS");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> dimsLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+ [](int64_t dim) { return dim != 1; });
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
llvm::SmallVector<int64_t> dimsRhs;
llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
[](int64_t dim) { return dim != 1; });
- if (dimsRhs.size() != 1)
- return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
+
+ if (dimsLhs.size() > 0 && dimsRhs.size() > 0)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
+
+ if (dimsLhs.size() != 1 && dimsRhs.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Irregular LHS or RHS shape.");
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
- assert(accTy && "Invalid accumulator");
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> dimsAcc;
llvm::copy_if(accShape, std::back_inserter(dimsAcc),
@@ -72,21 +81,39 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
// Lowers vector.contract into a broadcast+FMA sequence.
auto loc = contractOp.getLoc();
- auto castLhs = vector::ShapeCastOp::create(
- rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
- contractOp.getLhs());
- auto castRhs = vector::ShapeCastOp::create(
- rewriter, loc, VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
- contractOp.getRhs());
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
contractOp.getAcc());
- auto broadcastLhs = vector::BroadcastOp::create(
- rewriter, loc, castRhs.getResult().getType(), castLhs);
- auto fma =
- vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
- auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+ vector::FMAOp fma;
+
+ if (dimsRhs.size() > 0) {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), castLhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastRhs = vector::BroadcastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), castRhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
+ }
+
+ auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
rewriter.replaceOp(contractOp, castFma);
return success();
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 1dabbddbebb7e..59082b9761135 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
@@ -22,6 +24,90 @@ using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+ if (!map.isProjectedPermutation())
+ return failure();
+ SmallVector<mlir::utils::IteratorType> iterators(
+ map.getNumDims(), mlir::utils::IteratorType::reduction);
+ for (auto expr : map.getResults())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
+ iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
+ return iterators;
+}
+
+static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor) {
+ // Narrow down the operation type - VNNI only applies to contractions.
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return false;
+
+ auto matA = op->getOperand(0);
+ auto matB = op->getOperand(1);
+ auto typeA = dyn_cast<ShapedType>(matA.getType());
+ auto typeB = dyn_cast<ShapedType>(matB.getType());
+ unsigned rankA = typeA.getRank();
+ unsigned rankB = typeB.getRank();
+ // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+ if (rankA < 3 || rankB < 3)
+ return false;
+
+ // At least two reduction dimensions are expected:
+ // one for the VNNI factor and one for the K dimension
+ if (dims->k.size() < 2)
+ return false;
+
+ // Validate affine maps - VNNI computation should be defined by the two
+ // innermost reduction iterators.
+ // The input matrix dimensions layout must match the following:
+ // - matrix A - [...][K/vnniFactor][vnniFactor]
+ // - matrix B - [...][K/vnniFactor][N][vnniFactor]
+ auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+ if (failed(maybeIters))
+ return false;
+ SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+ if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+ mlir::utils::IteratorType::parallel)
+ return false;
+
+ // VNNI factor must be:
+ // - the innermost dimension of both inputs (with matching sizes)
+ // - statically known
+ // - a non-zero multiple of 2 and, when a factor is specified, equal to it
+ auto vnniDimSize = typeB.getShape().back();
+ if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+ vnniDimSize % 2 != 0)
+ return false;
+ if (typeA.getShape().back() != vnniDimSize)
+ return false;
+ if (blockingFactor && vnniDimSize != *blockingFactor)
+ return false;
+
+ // The split reduction dimension size should also match.
+ if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+ return false;
+
+ return true;
+}
+
// Implements packed type outer product contraction as a sequence
// of broadcast and packed dot-product operations.
//
@@ -41,50 +127,58 @@ struct VectorContractToPackedTypeDotProduct
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind");
- }
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isBF16() &&
!lhsTy.getElementType().isSignlessInteger(8))
return rewriter.notifyMatchFailure(
contractOp, "Only BF16/Int8 lowering is supported.");
- ArrayRef<int64_t> lhsShape = lhsTy.getShape();
- if (lhsTy.getElementType().isBF16() && lhsShape.back() != 2)
- return rewriter.notifyMatchFailure(
- contractOp, "The LHS vnni dim should be 2 for BF16.");
- if (lhsTy.getElementType().isSignlessInteger(8) && lhsShape.back() != 4)
- return rewriter.notifyMatchFailure(
- contractOp, "The LHS vnni dim should be 4 for Int8.");
+ if (lhsTy.getElementType().isBF16() &&
+ !isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), 2))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices not in VNNI format");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) &&
+ !isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), 4))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices not in VNNI format");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> dimsLhs;
llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
[](int64_t dim) { return dim != 1; });
- if (dimsLhs.size() != 1)
- return rewriter.notifyMatchFailure(contractOp, "Irregular LHS shape");
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
- if (lhsTy.getElementType().isBF16() && rhsShape.back() != 2)
- return rewriter.notifyMatchFailure(
- contractOp, "The RHS vnni dim should be 2 for BF16.");
- if (lhsTy.getElementType().isSignlessInteger(8) && rhsShape.back() != 4)
- return rewriter.notifyMatchFailure(
- contractOp, "The RHS vnni dim should be 4 for Int8.");
llvm::SmallVector<int64_t> dimsRhs;
llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
[](int64_t dim) { return dim != 1; });
- if (dimsRhs.size() != 2)
- return rewriter.notifyMatchFailure(contractOp, "Irregular RHS shape");
- VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
- assert(accTy && "Invalid accumulator");
- if (!accTy.getElementType().isF32() &&
- !accTy.getElementType().isSignlessInteger(32))
+ if ((dimsLhs.size() - 1) > 0 && (dimsRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(
- contractOp, "Only F32/Int32 accumulation is supported.");
+ contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
+
+ if ((dimsLhs.size() - 1) != 1 && (dimsRhs.size() - 1) != 1)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Irregular LHS or RHS shape.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+ (lhsTy.getElementType().isSignlessInteger(8) &&
+ !accTy.getElementType().isSignlessInteger(32)))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 for BF16 or Int32 for Int8 "
+ "accumulation type is supported.");
+
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> dimsAcc;
llvm::copy_if(accShape, std::back_inserter(dimsAcc),
@@ -93,43 +187,78 @@ struct VectorContractToPackedTypeDotProduct
return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
auto loc = contractOp.getLoc();
- auto castRhs = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get(dimsRhs.front() * dimsRhs.back(),
- rhsTy.getElementType()),
- contractOp.getRhs());
-
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
contractOp.getAcc());
- auto castLhs = vector::ShapeCastOp::create(
- rewriter, loc, VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
- contractOp.getLhs());
- auto bitcastLhs = vector::BitCastOp::create(
- rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
- castLhs);
- auto broadcastLhs = vector::BroadcastOp::create(
- rewriter, loc,
- VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
- bitcastLhs);
- auto bitcastLhsPkType = vector::BitCastOp::create(
- rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
-
Value dp;
- if (lhsTy.getElementType().isBF16()) {
- dp = x86vector::DotBF16Op::create(
+ if ((dimsRhs.size() - 1) > 0) {
+ auto castRhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
- bitcastLhsPkType, castRhs);
- }
+ VectorType::get(dimsRhs.front() * dimsRhs.back(),
+ rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto bitcastLhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castLhs);
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+ bitcastLhs);
+ auto bitcastLhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
- if (lhsTy.getElementType().isSignlessInteger(8)) {
- dp = x86vector::DotInt8Op::create(
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
+ bitcastLhsPkType, castRhs);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsLhs.front() * dimsLhs.back(),
+ lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto bitcastRhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castRhs);
+ auto broadcastRhs = vector::BroadcastOp::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
- castAcc, bitcastLhsPkType, castRhs);
+ VectorType::get({dimsLhs.front()}, rewriter.getIntegerType(32)),
+ bitcastRhs);
+ auto bitcastRhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(dimsLhs.front(), rewriter.getF32Type()), castAcc,
+ castLhs, bitcastRhsPkType);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(dimsLhs.front(), rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
}
if (dp) {
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
index 3a79037ca37c2..cf53710d839a1 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -19,8 +19,6 @@ func.func @matmul_outer_product_to_fma(
}
// CHECK-LABEL: @matmul_outer_product_to_fma
-// CHECK-COUNT-1: vector.shape_cast{{.*}}to vector<1xf32>
-// CHECK-COUNT-2: vector.shape_cast{{.*}}to vector<64xf32>
// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32>
// CHECK: vector.fma{{.*}}vector<64xf32>
// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32>
@@ -37,6 +35,40 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<64x1xf32>
+!vecB = vector<1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_outer_product_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1xf32>
!vecB = vector<1x1x64xf32>
!vecC = vector<1x1x64xf32>
@@ -71,6 +103,40 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<1x64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1xf32>
!vecB = vector<1x1x64xf32>
!vecC = vector<1x64xf32>
@@ -105,6 +171,40 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x64x1xf32>
+!vecB = vector<1x1x1xf32>
+!vecC = vector<64x1xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_bcst_B
+// CHECK: vector.broadcast
+// CHECK: vector.fma{{.*}}vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<3x1x1xf32>
!vecB = vector<3x1x64xf32>
!vecC = vector<3x1x64xf32>
@@ -208,3 +308,4 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 551f1f95ed9c0..268592c09a3a7 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -34,6 +34,40 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x16x1x2xbf16>
+!vecB = vector<1x1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_bf16dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x4xi8>
!vecB = vector<1x1x8x4xi8>
!vecC = vector<1x8xi32>
@@ -137,6 +171,41 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x8x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x8x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_int8dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x2xbf16>
!vecB = vector<1x16x2xbf16>
!vecC = vector<1x16xf32>
@@ -171,6 +240,40 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<16x1x2xbf16>
+!vecB = vector<1x1x2xbf16>
+!vecC = vector<16x1xf32>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_bf16dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx512.dot
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x4xi8>
!vecB = vector<1x8x4xi8>
!vecC = vector<1x8xi32>
@@ -372,3 +475,105 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_acc_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_acc_type
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x4xbf16>
+!vecB = vector<1x1x16x4xbf16>
+!vecC = vector<1x1x16xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_acc_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_acc_type
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x32xbf16>
+!vecC = vector<1x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @negative_brgemm_not_vnni(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_brgemm_not_vnni
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 942699fe8c8599798b102a648a174ae522576d9f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 21 Nov 2025 00:29:33 -0800
Subject: [PATCH 4/5] formatting, few more test-cases, and dp vector shape
checks.
---
.../TransformOps/X86VectorTransformOps.td | 2 +-
.../TransformOps/X86VectorTransformOps.cpp | 2 +-
.../Transforms/VectorContractToFMA.cpp | 46 +++++---
.../VectorContractToPackedTypeDotProduct.cpp | 98 +++++++++-------
.../X86Vector/vector-contract-to-fma.mlir | 33 ++++++
...or-contract-to-packed-type-dotproduct.mlir | 110 +++++++++++++++++-
6 files changed, 226 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 6192d5e31ffc5..3c5294ff14fc7 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 980c585848080..95db208207672 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.cpp -============================================//
+//===- X86VectorTransformOps.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 9349466ba1a34..9a6b4d07ecf58 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -42,7 +42,7 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
if (contractOp.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(contractOp,
- "Expects add combining kind");
+ "Expects add combining kind.");
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isF32())
@@ -50,50 +50,60 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
"Only F32 lowering is supported.");
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
- llvm::SmallVector<int64_t> dimsLhs;
- llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
[](int64_t dim) { return dim != 1; });
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
- llvm::SmallVector<int64_t> dimsRhs;
- llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
[](int64_t dim) { return dim != 1; });
- if (dimsLhs.size() > 0 && dimsRhs.size() > 0)
+ if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
return rewriter.notifyMatchFailure(
contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
- if (dimsLhs.size() != 1 && dimsRhs.size() != 1)
- return rewriter.notifyMatchFailure(contractOp,
- "Irregular LHS or RHS shape.");
+ if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
- return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+ return rewriter.notifyMatchFailure(contractOp,
+ "Accmulator is not a vector type");
+
+ if (!accTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Accmulator should be F32 type.");
ArrayRef<int64_t> accShape = accTy.getShape();
- llvm::SmallVector<int64_t> dimsAcc;
- llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
[](int64_t dim) { return dim != 1; });
- if (dimsAcc.size() != 1)
- return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(contractOp,
+ "A or B dimension should be non-unit.");
// Lowers vector.contract into a broadcast+FMA sequence.
auto loc = contractOp.getLoc();
auto castAcc = vector::ShapeCastOp::create(
- rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
vector::FMAOp fma;
- if (dimsRhs.size() > 0) {
+ // LHS has only unit dimensions: broadcast it to the vector size of the
+ // non-unit dimension of the RHS shape.
+ if (nonUnitDimRhs.size() > 0) {
auto castLhs = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
contractOp.getLhs());
auto castRhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
contractOp.getRhs());
auto broadcastLhs = vector::BroadcastOp::create(
rewriter, loc, castRhs.getResult().getType(), castLhs);
@@ -102,7 +112,7 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
} else {
auto castLhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
contractOp.getLhs());
auto castRhs = vector::ShapeCastOp::create(
rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 59082b9761135..68cd8700f5be6 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -129,7 +129,7 @@ struct VectorContractToPackedTypeDotProduct
if (contractOp.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(contractOp,
- "Expects add combining kind");
+ "Expects add combining kind.");
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isBF16() &&
@@ -137,40 +137,36 @@ struct VectorContractToPackedTypeDotProduct
return rewriter.notifyMatchFailure(
contractOp, "Only BF16/Int8 lowering is supported.");
- if (lhsTy.getElementType().isBF16() &&
- !isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(), 2))
+ unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+ if (!isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), blockingFactor))
return rewriter.notifyMatchFailure(contractOp,
- "Input matrices not in VNNI format");
-
- if (lhsTy.getElementType().isSignlessInteger(8) &&
- !isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(), 4))
- return rewriter.notifyMatchFailure(contractOp,
- "Input matrices not in VNNI format");
+ "Input matrices not in VNNI format.");
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
- llvm::SmallVector<int64_t> dimsLhs;
- llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
[](int64_t dim) { return dim != 1; });
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
- llvm::SmallVector<int64_t> dimsRhs;
- llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
[](int64_t dim) { return dim != 1; });
- if ((dimsLhs.size() - 1) > 0 && (dimsRhs.size() - 1) > 0)
- return rewriter.notifyMatchFailure(
- contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
-
- if ((dimsLhs.size() - 1) != 1 && (dimsRhs.size() - 1) != 1)
+ if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(contractOp,
- "Irregular LHS or RHS shape.");
+ "Excepts unit dimensions for either "
+ "LHS or RHS shape other than VNNI.");
+
+ if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
- return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type");
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
(lhsTy.getElementType().isSignlessInteger(8) &&
@@ -180,35 +176,55 @@ struct VectorContractToPackedTypeDotProduct
"accumulation type is supported.");
ArrayRef<int64_t> accShape = accTy.getShape();
- llvm::SmallVector<int64_t> dimsAcc;
- llvm::copy_if(accShape, std::back_inserter(dimsAcc),
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
[](int64_t dim) { return dim != 1; });
- if (dimsAcc.size() != 1)
- return rewriter.notifyMatchFailure(contractOp, "Irregular ACC shape");
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp, "A or B should be a non-unit dim in acc.");
+
+ // Non-unit dimensions should match the vector length of BF16 or Int8
+ // dot-product.
+ unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
+ : nonUnitDimRhs.front();
+ if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
+ nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
+ return rewriter.notifyMatchFailure(
+ contractOp, "BF16 dot-product operation expects non-unit (LHR or "
+ "RHS) dim and acc dim of size 4/8/16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
+ nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Int8 dot-product operation expects non-unit (LHR or "
+ "RHS) dim and acc dim of size 4/8.");
auto loc = contractOp.getLoc();
auto castAcc = vector::ShapeCastOp::create(
- rewriter, loc, VectorType::get(dimsAcc.front(), accTy.getElementType()),
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
Value dp;
- if ((dimsRhs.size() - 1) > 0) {
+ // Ignoring the VNNI dim, LHS has only unit dims: broadcast it to the size
+ // of the non-unit RHS dim. The `- 1` excludes the VNNI dim from the count.
+ if ((nonUnitDimRhs.size() - 1) > 0) {
auto castRhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsRhs.front() * dimsRhs.back(),
+ VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
rhsTy.getElementType()),
contractOp.getRhs());
auto castLhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsLhs.front(), lhsTy.getElementType()),
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
contractOp.getLhs());
auto bitcastLhs = vector::BitCastOp::create(
rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
castLhs);
auto broadcastLhs = vector::BroadcastOp::create(
rewriter, loc,
- VectorType::get({dimsRhs.front()}, rewriter.getIntegerType(32)),
+ VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
bitcastLhs);
auto bitcastLhsPkType = vector::BitCastOp::create(
rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
@@ -216,32 +232,32 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isBF16()) {
dp = x86vector::DotBF16Op::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rewriter.getF32Type()), castAcc,
- bitcastLhsPkType, castRhs);
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
+ castAcc, bitcastLhsPkType, castRhs);
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
dp = x86vector::DotInt8Op::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rewriter.getIntegerType(32)),
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
castAcc, bitcastLhsPkType, castRhs);
}
- } else {
+ } else { // RHS shape is unit dimension.
auto castLhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsLhs.front() * dimsLhs.back(),
+ VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
lhsTy.getElementType()),
contractOp.getLhs());
auto castRhs = vector::ShapeCastOp::create(
rewriter, loc,
- VectorType::get(dimsRhs.front(), rhsTy.getElementType()),
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
contractOp.getRhs());
auto bitcastRhs = vector::BitCastOp::create(
rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
castRhs);
auto broadcastRhs = vector::BroadcastOp::create(
rewriter, loc,
- VectorType::get({dimsLhs.front()}, rewriter.getIntegerType(32)),
+ VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
bitcastRhs);
auto bitcastRhsPkType = vector::BitCastOp::create(
rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
@@ -249,14 +265,14 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isBF16()) {
dp = x86vector::DotBF16Op::create(
rewriter, loc,
- VectorType::get(dimsLhs.front(), rewriter.getF32Type()), castAcc,
- castLhs, bitcastRhsPkType);
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
+ castAcc, castLhs, bitcastRhsPkType);
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
dp = x86vector::DotInt8Op::create(
rewriter, loc,
- VectorType::get(dimsLhs.front(), rewriter.getIntegerType(32)),
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
castAcc, castLhs, bitcastRhsPkType);
}
}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
index cf53710d839a1..e506b166d43ff 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir
@@ -309,3 +309,36 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+!vecA = vector<1x1x1xf32>
+!vecB = vector<1x1x64xf32>
+!vecC = vector<1x1x64xi32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_accumulator_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_accumulator_type
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 268592c09a3a7..65676cbae772c 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -484,7 +484,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_acc_type(
+func.func @negative_float_acc_type(
%arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
{
%0 = vector.contract {
@@ -496,7 +496,7 @@ func.func @negative_acc_type(
return %0 : !vecC
}
-// CHECK-LABEL: @negative_acc_type
+// CHECK-LABEL: @negative_float_acc_type
// CHECK-NOT: x86vector.avx512.dot
// CHECK: vector.contract
@@ -512,13 +512,47 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x8x4xi8>
+!vecC = vector<1x1x8xi8>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_int_acc_type(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_int_acc_type
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x4xbf16>
!vecB = vector<1x1x16x4xbf16>
!vecC = vector<1x1x16xbf16>
#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_acc_type(
+func.func @negative_wrong_vnni_blocking_factor_bf16(
%arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
{
%0 = vector.contract {
@@ -530,7 +564,7 @@ func.func @negative_acc_type(
return %0 : !vecC
}
-// CHECK-LABEL: @negative_acc_type
+// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
// CHECK-NOT: x86vector.avx512.dot
// CHECK: vector.contract
@@ -577,3 +611,71 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_int8(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_wrong_vector_shape_int8
+// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x32x2xbf16>
+!vecC = vector<1x1x32xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_wrong_vector_shape_bf16(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_wrong_vector_shape_bf16
+// CHECK-NOT: x86vector.avx512.dot
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 960c24c66e1e96b7b31b0177cd189985ad1ba80d Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 21 Nov 2025 00:56:52 -0800
Subject: [PATCH 5/5] fix clang-format issue
---
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
index 9a6b4d07ecf58..f6169cb6b5400 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -83,8 +83,8 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
[](int64_t dim) { return dim != 1; });
if (nonUnitDimAcc.size() != 1)
- return rewriter.notifyMatchFailure(contractOp,
- "A or B dimension should be non-unit.");
+ return rewriter.notifyMatchFailure(
+ contractOp, "A or B dimension should be non-unit.");
// Lowers vector.contract into a broadcast+FMA sequence.
auto loc = contractOp.getLoc();
More information about the Mlir-commits
mailing list