[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Arun Thangamani
llvmlistbot at llvm.org
Fri Dec 12 01:20:00 PST 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/170267
>From fe09e05837b2ee469da250562e45c4bfa15865ff Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 2 Dec 2025 01:14:35 -0800
Subject: [PATCH 1/7] bf16 vector.contract lowering to vector.fma using AVX2
BF16 packed ops.
---
.../TransformOps/X86VectorTransformOps.td | 10 +
.../mlir/Dialect/X86Vector/Transforms.h | 2 +
.../Dialect/X86Vector/Utils/X86VectorUtils.h | 43 ++++
mlir/lib/Dialect/X86Vector/CMakeLists.txt | 1 +
.../TransformOps/X86VectorTransformOps.cpp | 5 +
.../X86Vector/Transforms/CMakeLists.txt | 4 +-
.../Transforms/VectorContractBF16ToFMA.cpp | 225 +++++++++++++++++
.../VectorContractToPackedTypeDotProduct.cpp | 87 +------
.../Dialect/X86Vector/Utils/CMakeLists.txt | 13 +
.../X86Vector/Utils/X86VectorUtils.cpp | 110 +++++++++
.../vector-contract-bf16-to-fma.mlir | 229 ++++++++++++++++++
11 files changed, 642 insertions(+), 87 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
create mode 100644 mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..b4f1c0e12ef9b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,16 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+// Transform op whose populatePatterns() adds the patterns collected by
+// x86vector::populateVectorContractBF16ToFMAPatterns.
+def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_bf16_to_fma",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that BF16 vector contract operation can be lowered to a FMA.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..aba903845e429 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,8 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);
+/// Collect patterns that lower a BF16 vector.contract in VNNI layout (VNNI
+/// factor 2, F32 accumulator) into vector.fma ops fed by AVX2 BF16 packed
+/// broadcast and odd/even-indexed conversion ops.
+void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
new file mode 100644
index 0000000000000..8a76009ddb907
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -0,0 +1,43 @@
+//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <cstdint>
+#include <optional>
+#include <string>
+
+namespace mlir {
+class Type;
+class ShapedType;
+class OpOperand;
+class AffineDimExpr;
+class AffineMap;
+class VectorType;
+class Operation;
+
+namespace x86vector {
+// NOTE(review): not referenced by the code in this patch; the enumerator
+// names suggest the expected rank of a VNNI-packed operand per operation
+// kind -- confirm with the intended users.
+enum class VnniOperandRank {
+ TRANSPOSE = 3,
+ GEMM = 3,
+ BRGEMM_INS = 4,
+ BRGEMM_OUTS = 3
+};
+
+// Returns true if the A and B inputs of the contraction `op` (operands 0 and
+// 1) are in VNNI layout according to `indexingMaps` (the A, B and C maps, in
+// that order):
+//   - A: [...][K/vnniFactor][vnniFactor]
+//   - B: [...][K/vnniFactor][N][vnniFactor]
+// The VNNI factor (innermost dimension) must be static, non-zero, even and the
+// same for A and B.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor = std::nullopt);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index cb1e9d01821a2..329a6c3e80254 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..172f159b43f80 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}
+// Registers the BF16 vector.contract -> FMA rewrite patterns of the X86Vector
+// dialect into `patterns`.
+void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..9eb94691753cf 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,14 +3,16 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
VectorContractToPackedTypeDotProduct.cpp
+ VectorContractBF16ToFMA.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
- MLIRX86VectorDialect
MLIRIR
MLIRLinalgDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRVectorUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorUtils
)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
new file mode 100644
index 0000000000000..58cc7f7497044
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -0,0 +1,225 @@
+//===- VectorContractBF16ToFMA.cpp - BF16 contract to FMA lowering --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+/// Creates the `memref.subview`s that the AVX2 BF16 packed ops read from for
+/// the contraction input `prodOp`. The input must be produced by an unmasked
+/// `vector.transfer_read` with a minor-identity permutation map, or by a
+/// `vector.load`, in both cases reading from a memref.
+///
+/// The first subview reuses the producer's offsets and has unit strides. Its
+/// sizes are all 1 except:
+///   - the innermost (VNNI) dimension, which gets size `vnniDim`;
+///   - the dimension `mnDimIndx` positions from the end, which gets `mnDim`.
+///
+/// For the broadcast input (`mnDim == 1`) two single-element subviews are
+/// returned, relative to the producer's VNNI offset `off`:
+///   - first the element at `off + 1` (odd index);
+///   - then the element at `off` (even index).
+/// `off` must be a constant in that case. Otherwise a single subview is
+/// returned.
+///
+/// Returns failure, before creating any IR, if the producer is unsupported.
+static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
+                          mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
+                          int64_t mnDimIndx) {
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+
+  if (auto transferRead =
+          prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
+    // A masked or permuted read cannot be expressed as a plain subview.
+    if (transferRead.getMask() ||
+        !transferRead.getPermutationMap().isMinorIdentity())
+      return failure();
+    srcOperation = transferRead.getOperand(0);
+    indexVals.assign(transferRead.getIndices().begin(),
+                     transferRead.getIndices().end());
+  } else if (auto load = prodOp.getDefiningOp<mlir::vector::LoadOp>()) {
+    srcOperation = load.getOperand(0);
+    indexVals.assign(load.getIndices().begin(), load.getIndices().end());
+  }
+
+  // memref.subview only applies to memrefs: reject tensors and unknown
+  // producers before anything is created.
+  if (!srcOperation || !isa<MemRefType>(srcOperation.getType()))
+    return failure();
+
+  const int64_t rank = static_cast<int64_t>(indexVals.size());
+  if (rank == 0 || mnDimIndx < 1 || mnDimIndx > rank)
+    return failure();
+  const int64_t vnniPos = rank - 1;
+
+  // The broadcast input addresses individual VNNI elements relative to the
+  // producer's offset, so that offset has to be a known constant.
+  std::optional<int64_t> vnniOffset;
+  if (mnDim == 1) {
+    vnniOffset = getConstantIntValue(indexVals[vnniPos]);
+    if (!vnniOffset)
+      return failure();
+  }
+
+  OpFoldResult one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> sizes(rank, one);
+  sizes[vnniPos] = rewriter.getIndexAttr(vnniDim);
+  sizes[rank - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+
+  llvm::SmallVector<mlir::memref::SubViewOp> subviews;
+
+  // Odd-indexed element first (only differs from the producer for mnDim == 1).
+  if (mnDim == 1)
+    indexVals[vnniPos] = rewriter.getIndexAttr(*vnniOffset + 1);
+  subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                               indexVals, sizes, strides));
+
+  // Even-indexed element for the broadcast input.
+  if (mnDim == 1) {
+    indexVals[vnniPos] = rewriter.getIndexAttr(*vnniOffset);
+    sizes[vnniPos] = rewriter.getIndexAttr(1);
+    subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                                 indexVals, sizes, strides));
+  }
+
+  return subviews;
+}
+
+/// Lowers a BF16 `vector.contract` with F32 accumulation into two `vector.fma`
+/// ops fed by AVX2 BF16 packed operations. The inputs must be in VNNI layout
+/// with a VNNI factor of 2.
+///
+/// Apart from the VNNI dimension, exactly one input (the "packed" input) has a
+/// non-unit dimension: its M (LHS) or N (RHS) dimension, of size 4 or 8. The
+/// accumulator has a single non-unit dimension of the same size. The other
+/// input (the "broadcast" input) holds a single VNNI pair (e0, e1). The result
+/// is computed as:
+///   tmp = fma(bcst(e1), cvt_odd(packed), acc)
+///   res = fma(bcst(e0), cvt_even(packed), tmp)
+/// The packed ops read memory directly, so both inputs must come from a
+/// `vector.transfer_read` / `vector.load` of a memref.
+struct VectorContractBF16ToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    if (!isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), 2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format.");
+
+    // Sizes of the non-unit dimensions of `shape`, in order.
+    auto getNonUnitDims = [](ArrayRef<int64_t> shape) {
+      SmallVector<int64_t> dims;
+      llvm::copy_if(shape, std::back_inserter(dims),
+                    [](int64_t dim) { return dim != 1; });
+      return dims;
+    };
+
+    // The VNNI check above guarantees an innermost dimension of size 2 on both
+    // inputs, so each list has at least one entry.
+    SmallVector<int64_t> nonUnitDimLhs = getNonUnitDims(lhsTy.getShape());
+    SmallVector<int64_t> nonUnitDimRhs =
+        getNonUnitDims(contractOp.getRhsType().getShape());
+
+    if (nonUnitDimLhs.size() > 1 && nonUnitDimRhs.size() > 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape other than VNNI.");
+
+    if (nonUnitDimLhs.size() != 2 && nonUnitDimRhs.size() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    SmallVector<int64_t> nonUnitDimAcc = getNonUnitDims(accTy.getShape());
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    // Exactly one input now has two non-unit dimensions (M or N, then VNNI).
+    // That input is read with the packed conversions; the other is broadcast.
+    const bool rhsIsPacked = nonUnitDimRhs.size() == 2;
+    const int64_t nonUnitDim =
+        rhsIsPacked ? nonUnitDimRhs.front() : nonUnitDimLhs.front();
+
+    // The M/N size becomes the F32 vector length of the packed ops (`dstType`
+    // below), so it must be 4 or 8 AND match the accumulator.
+    if ((nonUnitDim != 4 && nonUnitDim != 8) ||
+        nonUnitDimAcc.front() != nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 packed load operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8.");
+
+    Location loc = contractOp.getLoc();
+    Value bcstInput = rhsIsPacked ? contractOp.getLhs() : contractOp.getRhs();
+    Value packedInput = rhsIsPacked ? contractOp.getRhs() : contractOp.getLhs();
+    // Position of the M/N dimension counted from the end:
+    // RHS is [..., K, N, vnni] and LHS is [..., M, K, vnni].
+    const int64_t packedDimFromEnd = rhsIsPacked ? 2 : 3;
+
+    FailureOr<SmallVector<memref::SubViewOp>> unitDimSubview =
+        getSubviewFromVectorInput(loc, rewriter, bcstInput, 1, 1, 2);
+    FailureOr<SmallVector<memref::SubViewOp>> nonUnitDimSubview =
+        getSubviewFromVectorInput(loc, rewriter, packedInput, nonUnitDim, 2,
+                                  packedDimFromEnd);
+    if (failed(unitDimSubview) || failed(nonUnitDimSubview)) {
+      // Erase whatever the successful call created so that returning failure
+      // does not leave newly created ops behind.
+      auto eraseViews = [&](FailureOr<SmallVector<memref::SubViewOp>> &views) {
+        if (succeeded(views))
+          for (memref::SubViewOp view : *views)
+            rewriter.eraseOp(view);
+      };
+      eraseViews(unitDimSubview);
+      eraseViews(nonUnitDimSubview);
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects inputs read from a memref by "
+                      "vector.transfer_read or vector.load.");
+    }
+
+    // From here on the rewrite cannot fail.
+    mlir::VectorType dstType =
+        mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+    auto castAcc = vector::ShapeCastOp::create(rewriter, loc, dstType,
+                                               contractOp.getAcc());
+
+    // First FMA: odd-indexed (second) VNNI elements.
+    auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, (*unitDimSubview)[0]);
+    auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+        rewriter, loc, dstType, (*nonUnitDimSubview)[0]);
+    auto oddIndxFMA =
+        vector::FMAOp::create(rewriter, loc, loadBcstOddIndxElementToF32,
+                              loadOddIndxElementF32, castAcc);
+
+    // With exactly one use, emit the second half right before that user.
+    // NOTE(review): presumably to keep the final FMA next to its consumer.
+    if (contractOp->hasOneUse())
+      rewriter.setInsertionPoint(*contractOp->user_begin());
+
+    // Second FMA: even-indexed (first) VNNI elements, accumulated on top of
+    // the first FMA.
+    auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, (*unitDimSubview)[1]);
+    auto loadEvenIndxElementF32 =
+        x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                       (*nonUnitDimSubview)[0]);
+    vector::FMAOp fma =
+        vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
+                              loadEvenIndxElementF32, oddIndxFMA);
+
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+    rewriter.replaceOp(contractOp, castFma);
+    return success();
+  }
+};
+
+/// Adds the BF16 vector.contract -> vector.fma lowering pattern to `patterns`.
+void x86vector::populateVectorContractBF16ToFMAPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<VectorContractBF16ToFMA>(context);
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 1e64811db910b..a00a3e5bdd766 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -26,92 +27,6 @@ using namespace mlir::x86vector;
namespace {
-static FailureOr<SmallVector<mlir::utils::IteratorType>>
-inferIteratorsFromOutMap(AffineMap map) {
- if (!map.isProjectedPermutation())
- return failure();
- SmallVector<mlir::utils::IteratorType> iterators(
- map.getNumDims(), mlir::utils::IteratorType::reduction);
- for (auto expr : map.getResults())
- if (auto dim = dyn_cast<AffineDimExpr>(expr))
- iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
- return iterators;
-}
-
-// Returns true if the operation is in VNNI layout.
-// Optionally, the check can be constrained to a specific VNNI blocking factor.
-static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
- std::optional<unsigned> blockingFactor) {
- // Narrow down type operations - VNNI only applies to contractions.
- FailureOr<linalg::ContractionDimensions> dims =
- linalg::inferContractionDims(indexingMaps);
- if (failed(dims))
- return false;
-
- auto matA = op->getOperand(0);
- auto matB = op->getOperand(1);
- auto typeA = dyn_cast<ShapedType>(matA.getType());
- auto typeB = dyn_cast<ShapedType>(matB.getType());
- unsigned rankA = typeA.getRank();
- unsigned rankB = typeB.getRank();
- // VNNI format requires at least 1 parallel and 2 reduction dimensions.
- if (rankA < 3 || rankB < 3)
- return false;
-
- // At least two reduction dimensions are expected:
- // one for the VNNI factor and one for the K dimension
- if (dims->k.size() < 2)
- return false;
-
- // Validate affine maps - VNNI computation should be defined by the two
- // innermost reduction iterators.
- // The input matrix dimensions layout must match the following:
- // - matrix A - [...][K/vnniFactor][vnniFactor]
- // - matrix B - [...][K/vnniFactor][N][vnniFactor]
- auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
- if (failed(maybeIters))
- return false;
- SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
- AffineMap mapA = indexingMaps[0];
- AffineMap mapB = indexingMaps[1];
-
- auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
- auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
- if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
- iteratorTypes[vnniDimA.getPosition()] !=
- mlir::utils::IteratorType::reduction)
- return false;
- auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
- auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
- if (!redDimA || !redDimB || redDimA != redDimB ||
- iteratorTypes[redDimA.getPosition()] !=
- mlir::utils::IteratorType::reduction)
- return false;
- auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
- if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
- mlir::utils::IteratorType::parallel)
- return false;
-
- // VNNI factor must be:
- // - the innermost inputs' dimension
- // - statically known
- // - multiple of 2 or equal to the specified factor
- auto vnniDimSize = typeB.getShape().back();
- if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
- vnniDimSize % 2 != 0)
- return false;
- if (typeA.getShape().back() != vnniDimSize)
- return false;
- if (blockingFactor && vnniDimSize != *blockingFactor)
- return false;
-
- // The split reduction dimension size should also match.
- if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
- return false;
-
- return true;
-}
-
// Implements packed type outer product contraction as a sequence
// of broadcast and packed dot-product operations.
//
diff --git a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..6a2da861737ed
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRX86VectorUtils
+  X86VectorUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector/Utils
+
+  # MLIRLinalgDialect provides linalg::inferContractionDims, which is used by
+  # isInVnniLayout.
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
new file mode 100644
index 0000000000000..e06307eeedcdb
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -0,0 +1,110 @@
+//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the X86Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+
+#define DEBUG_TYPE "x86vector-utils"
+
+using namespace mlir;
+
+/// Derives iterator kinds from the output indexing map: every dimension that
+/// appears in the map's results is parallel, all others are reductions.
+/// Fails if the map is not a projected permutation.
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  if (!map.isProjectedPermutation())
+    return failure();
+
+  using mlir::utils::IteratorType;
+  SmallVector<IteratorType> kinds(map.getNumDims(), IteratorType::reduction);
+  for (AffineExpr result : map.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    if (!dimExpr)
+      continue;
+    kinds[dimExpr.getPosition()] = IteratorType::parallel;
+  }
+  return kinds;
+}
+
+// Preconditions: operands 0 and 1 of `op` are the A and B inputs, and
+// `indexingMaps` holds the A, B and C maps, in that order. The check returns
+// false when an input is not a ranked shaped type, or when a map's result
+// count differs from its input's rank.
+//
+// Returns true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+bool mlir::x86vector::isInVnniLayout(Operation *op,
+                                     ArrayRef<AffineMap> indexingMaps,
+                                     std::optional<unsigned> blockingFactor) {
+  // Narrow down type operations - VNNI only applies to contractions.
+  FailureOr<linalg::ContractionDimensions> dims =
+      linalg::inferContractionDims(indexingMaps);
+  if (failed(dims))
+    return false;
+
+  // Both inputs and all three maps are accessed positionally below.
+  if (indexingMaps.size() < 3 || op->getNumOperands() < 2)
+    return false;
+
+  auto matA = op->getOperand(0);
+  auto matB = op->getOperand(1);
+  auto typeA = dyn_cast<ShapedType>(matA.getType());
+  auto typeB = dyn_cast<ShapedType>(matB.getType());
+  // Rank and shape queries require ranked shaped types.
+  if (!typeA || !typeB || !typeA.hasRank() || !typeB.hasRank())
+    return false;
+  unsigned rankA = typeA.getRank();
+  unsigned rankB = typeB.getRank();
+  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+  if (rankA < 3 || rankB < 3)
+    return false;
+
+  // At least two reduction dimensions are expected:
+  // one for the VNNI factor and one for the K dimension
+  if (dims->k.size() < 2)
+    return false;
+
+  // Validate affine maps - VNNI computation should be defined by the two
+  // innermost reduction iterators.
+  // The input matrix dimensions layout must match the following:
+  //   - matrix A - [...][K/vnniFactor][vnniFactor]
+  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
+  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+  if (failed(maybeIters))
+    return false;
+  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+  AffineMap mapA = indexingMaps[0];
+  AffineMap mapB = indexingMaps[1];
+  // Map results are looked up by input dimension below.
+  if (mapA.getNumResults() != rankA || mapB.getNumResults() != rankB)
+    return false;
+
+  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+      iteratorTypes[vnniDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+  if (!redDimA || !redDimB || redDimA != redDimB ||
+      iteratorTypes[redDimA.getPosition()] !=
+          mlir::utils::IteratorType::reduction)
+    return false;
+  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+                           mlir::utils::IteratorType::parallel)
+    return false;
+
+  // VNNI factor must be:
+  //   - the innermost inputs' dimension
+  //   - statically known
+  //   - multiple of 2 or equal to the specified factor
+  auto vnniDimSize = typeB.getShape().back();
+  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+      vnniDimSize % 2 != 0)
+    return false;
+  if (typeA.getShape().back() != vnniDimSize)
+    return false;
+  if (blockingFactor && vnniDimSize != static_cast<int64_t>(*blockingFactor))
+    return false;
+
+  // The split reduction dimension size should also match.
+  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+    return false;
+
+  return true;
+}
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
new file mode 100644
index 0000000000000..0bee34f23a8a4
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -0,0 +1,229 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+// BRGEMM-style contraction: d0 is a reduction dimension shared by both inputs.
+// Both inputs are read with vector.transfer_read. The LHS holds a single VNNI
+// pair and the RHS holds 8 N columns, so the LHS is the broadcast operand. The
+// expected output is an odd-indexed broadcast/convert/FMA sequence followed by
+// an even-indexed one.
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Same BRGEMM-style contraction as the previous case, but both inputs are read
+// with vector.load instead of vector.transfer_read.
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_load(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// BRGEMM-style contraction with the roles swapped: the LHS holds 8 M rows and
+// the RHS a single VNNI pair, so the RHS is the broadcast operand and the
+// accumulator's non-unit dimension is M (8x1).
+!vecA = vector<1x8x1x2xbf16>
+!vecB = vector<1x1x1x2xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<1x32x1x2xbf16>
+!memrefB = memref<1x1x4x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_fma_load_bcst_B(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_fma_load_bcst_B
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Batch-matmul contraction: d0 is a parallel dimension that also appears in
+// the 1x1x8 accumulator. The inputs are read with vector.load and the LHS is
+// the broadcast operand.
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x8x2xbf16>
+!vecC = vector<1x1x8xf32>
+!memrefA = memref<1x4x1x2xbf16>
+!memrefB = memref<1x1x32x2xbf16>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_fma_load(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Plain matmul with rank-3 VNNI inputs (no batch dimension), read with
+// vector.load. The LHS holds a single VNNI pair and is broadcast against 8 N
+// columns of the RHS.
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_outer_product_to_fma_load(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @matmul_outer_product_to_fma_load
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 36499af835134b1d69ca4b6aab73b7ef4d6fc3e4 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 2 Dec 2025 01:38:13 -0800
Subject: [PATCH 2/7] fix - clang format errors
---
mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index e06307eeedcdb..a934188bb058d 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -37,8 +37,9 @@ inferIteratorsFromOutMap(AffineMap map) {
// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
-bool mlir::x86vector::isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
- std::optional<unsigned> blockingFactor) {
+bool mlir::x86vector::isInVnniLayout(Operation *op,
+ ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor) {
// Narrow down type operations - VNNI only applies to contractions.
FailureOr<linalg::ContractionDimensions> dims =
linalg::inferContractionDims(indexingMaps);
>From af67c0c4b6c0673b460ac5b056f369523f221638 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 8 Dec 2025 05:15:09 -0800
Subject: [PATCH 3/7] added new comments + negative test-cases
---
.../TransformOps/X86VectorTransformOps.td | 3 +-
.../mlir/Dialect/X86Vector/Transforms.h | 3 +
.../Transforms/VectorContractBF16ToFMA.cpp | 75 +++++++++++++++----
.../vector-contract-bf16-to-fma.mlir | 72 ++++++++++++++++++
4 files changed, 139 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index b4f1c0e12ef9b..9c3ed1c8092a1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -42,7 +42,8 @@ def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_bf16_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that BF16 vector contract operation can be lowered to a FMA.
+ Collect patterns to lower a BF16 type vector.contract operation
+ to an FMA via emulation lowering using BF16 packed operations.
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index aba903845e429..c4960ae28cb4f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,9 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);
+// A set of patterns for lowering 32-bit packed BF16 vector contraction
+// operations to vector fused multiply-add (FMA) operations, following
+// the emulation-based approach using BF16 packed operations.
void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 58cc7f7497044..43c3a38b277b7 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -19,11 +19,27 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
+// This function retrieves the source operation of the load or transfer
+// reads and creates subviews for the BF16 packed-operations to
+// broadcast or load BF16 elements as F32 packed elements.
+//
+// For example:
+// ```
+// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// vector.load %arg0[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
@@ -52,6 +68,10 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
if (!srcOperation)
return failure();
+ Type srcType = srcOperation.getType();
+ if (!llvm::isa<mlir::MemRefType>(srcType))
+ return failure();
+
llvm::SmallVector<OpFoldResult> strides;
llvm::SmallVector<OpFoldResult> sizes;
@@ -83,6 +103,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
return subviews;
}
+// Implements an outer-product contraction as a sequence of BF16 packed
+// even/odd-indexed loads and FMA operations.
+//
+// For example:
+// ```
+// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
+// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
+// return vector.contract %1, %2, %arg1
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+// return vector.fma %4, %5, %3
+// ```
struct VectorContractBF16ToFMA
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -94,6 +132,9 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind.");
+ // TODO: Move this validation to a comon utility folder. Planned to
+ // do once (code refactoring), all architecture specific nanokernel
+ // passes are merged into the repo.
VectorType lhsTy = contractOp.getLhsType();
if (!lhsTy.getElementType().isBF16())
return rewriter.notifyMatchFailure(contractOp,
@@ -141,8 +182,7 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(
contractOp, "A or B should be a non-unit dim in acc.");
- // Non-unit dimensions should match the vector length of BF16 or Int8
- // dot-product.
+ // Non-unit dimensions should match the vector length of BF16.
unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
: nonUnitDimRhs.front();
if (nonUnitDim != 4 && nonUnitDim != 8 &&
@@ -151,25 +191,25 @@ struct VectorContractBF16ToFMA
contractOp, "BF16 packed load operation expects non-unit (LHR or "
"RHS) dim and acc dim of size 4/8.");
+ // Lower vector.contract to FMAs with help of BF16 packed ops.
auto loc = contractOp.getLoc();
- auto castAcc = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
- contractOp.getAcc());
- mlir::VectorType dstType =
- mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
-
llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview;
llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview;
+ // create the unit-dimension LHS or RHS subview and the
+ // corresponding non-unit dimension LHS or RHS subview on the other-side.
+ // For example, if LHS has type vector<1x1x2xbf16> and RHS has type
+ // vector<1x8x2xbf16>, we create two subview for the LHS and one subview
+ // for the RHS. In the opposite case (non-unit dimension on the LHS), we
+ // do vice-versa.
if ((nonUnitDimRhs.size() - 1) > 0) {
-
auto unitSubview = getSubviewFromVectorInput(
loc, rewriter, contractOp.getLhs(), 1, 1, 2);
auto nonUnitSubview = getSubviewFromVectorInput(
loc, rewriter, contractOp.getRhs(), nonUnitDimRhs.front(), 2, 2);
if (failed(unitSubview) || failed(nonUnitSubview))
- return failure();
+ return rewriter.notifyMatchFailure(
+ contractOp, " The input source is not MemRef Type.");
unitDimSubview = *unitSubview;
nonUnitDimSubview = *nonUnitSubview;
@@ -180,12 +220,21 @@ struct VectorContractBF16ToFMA
auto unitSubview = getSubviewFromVectorInput(
loc, rewriter, contractOp.getRhs(), 1, 1, 2);
if (failed(unitSubview) || failed(nonUnitSubview))
- return failure();
+ return rewriter.notifyMatchFailure(
+ contractOp, " The input source is not MemRef Type.");
unitDimSubview = *unitSubview;
nonUnitDimSubview = *nonUnitSubview;
}
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+ mlir::VectorType dstType =
+ mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+
+ // Load, broadcast, and do FMA for odd indexed BF16 elements.
auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
@@ -203,12 +252,12 @@ struct VectorContractBF16ToFMA
rewriter.setInsertionPoint(users[0]);
}
+ // Load, broadcast, and do FMA for even indexed BF16 elements.
auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]);
auto loadEvenIndxElementF32 =
x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
nonUnitDimSubview[0]);
-
vector::FMAOp fma =
vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
loadEvenIndxElementF32, oddIndxFMA);
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index 0bee34f23a8a4..0b92e9367365f 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -227,3 +227,75 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+
+func.func @negative_tensor_type(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+ %0 = ub.poison : bf16
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
+ %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
+ %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+ %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
+ %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+ %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
+ %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
+ return %8 : vector<1x16xf32>
+}
+
+// CHECK-LABEL: @negative_tensor_type
+// CHECK-NOT: vector.fma
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x1x2xbf16>
+!vecB = vector<1x1x16x2xbf16>
+!vecC = vector<1x1x16xf32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_no_memref_src(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @negative_no_memref_src
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From a9ba397ac125342fadf45e58c0c80dd873b8b3bd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 9 Dec 2025 06:46:14 -0800
Subject: [PATCH 4/7] code refactor: new comments, code simplification etc.
---
.../Transforms/VectorContractBF16ToFMA.cpp | 98 +++++++++++--------
.../X86Vector/Utils/X86VectorUtils.cpp | 2 +-
2 files changed, 58 insertions(+), 42 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 43c3a38b277b7..375364a4f1c7b 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -50,20 +50,17 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
Value srcOperation;
SmallVector<OpFoldResult> indexVals;
- if (auto transferRead =
- prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
- srcOperation = transferRead.getOperand(0);
- SmallVector<OpFoldResult> indexValues(transferRead.getIndices().begin(),
- transferRead.getIndices().end());
- indexVals = indexValues;
- }
+ Operation *defOp = prodOp.getDefiningOp();
+ if (!defOp)
+ return failure();
- if (auto load = prodOp.getDefiningOp<mlir::vector::LoadOp>()) {
- srcOperation = load.getOperand(0);
- SmallVector<OpFoldResult> indexValues(load.getIndices().begin(),
- load.getIndices().end());
- indexVals = indexValues;
- }
+ llvm::TypeSwitch<Operation *>(defOp)
+ .Case<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(
+ [&](auto readOp) {
+ srcOperation = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ });
if (!srcOperation)
return failure();
@@ -75,14 +72,20 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
llvm::SmallVector<OpFoldResult> strides;
llvm::SmallVector<OpFoldResult> sizes;
- for (unsigned int i = 0; i < indexVals.size(); i++) {
+ auto nonVNNIDimSize = indexVals.size() - 1;
+ // Create the size and stride offsets.
+ for (unsigned int i = 0; i < nonVNNIDimSize; i++) {
strides.push_back(rewriter.getIndexAttr(1));
sizes.push_back(rewriter.getIndexAttr(1));
}
- sizes[indexVals.size() - 1] = rewriter.getIndexAttr(vnniDim);
+ strides.push_back(rewriter.getIndexAttr(1));
+ sizes.push_back(rewriter.getIndexAttr(vnniDim));
+
+ // update the unit/nonUnit Dim size eiither it is A(LHS) or B(RHS).
sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+ // for unitDim, first broadcast odd element, so index is set to C1.
if (mnDim == 1) {
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
}
@@ -91,6 +94,11 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
indexVals, sizes, strides);
subviews.push_back(subview);
+ // For unit-dims, two subviews should be created for the odd and even
+ // indexed BF16 element because x86vector.avx.bcst_to_f32.packed op
+ // loads and broadcast the first BF16 element into packed F32. It
+ // cannot distinguish between even and odd BF16 elements within a
+ // packed pair.
if (mnDim == 1) {
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
@@ -193,8 +201,6 @@ struct VectorContractBF16ToFMA
// Lower vector.contract to FMAs with help of BF16 packed ops.
auto loc = contractOp.getLoc();
- llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview;
- llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview;
// create the unit-dimension LHS or RHS subview and the
// corresponding non-unit dimension LHS or RHS subview on the other-side.
@@ -202,30 +208,40 @@ struct VectorContractBF16ToFMA
// vector<1x8x2xbf16>, we create two subview for the LHS and one subview
// for the RHS. In the opposite case (non-unit dimension on the LHS), we
// do vice-versa.
- if ((nonUnitDimRhs.size() - 1) > 0) {
- auto unitSubview = getSubviewFromVectorInput(
- loc, rewriter, contractOp.getLhs(), 1, 1, 2);
- auto nonUnitSubview = getSubviewFromVectorInput(
- loc, rewriter, contractOp.getRhs(), nonUnitDimRhs.front(), 2, 2);
- if (failed(unitSubview) || failed(nonUnitSubview))
- return rewriter.notifyMatchFailure(
- contractOp, " The input source is not MemRef Type.");
-
- unitDimSubview = *unitSubview;
- nonUnitDimSubview = *nonUnitSubview;
-
- } else {
- auto nonUnitSubview = getSubviewFromVectorInput(
- loc, rewriter, contractOp.getLhs(), nonUnitDimRhs.front(), 2, 3);
- auto unitSubview = getSubviewFromVectorInput(
- loc, rewriter, contractOp.getRhs(), 1, 1, 2);
- if (failed(unitSubview) || failed(nonUnitSubview))
- return rewriter.notifyMatchFailure(
- contractOp, " The input source is not MemRef Type.");
-
- unitDimSubview = *unitSubview;
- nonUnitDimSubview = *nonUnitSubview;
- }
+ bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
+ // Select which operand is "unit" and which is "non-unit".
+ Value unitSrc =
+ rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+ Value nonUnitSrc =
+ rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+
+ // mnDim index differs depending on the orientation.
+ int unitMnDim = rhsHasMultipleNonUnitDims ? 2 : 2; // same for both
+ int nonUnitMnDim = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+
+ // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
+ int unitVnni = 1;
+ int nonUnitVnni = 2;
+
+ // Non-unit dim size.
+ int nonUnitSize = nonUnitDimRhs.front();
+
+ // Build subviews.
+ auto unitSubview = getSubviewFromVectorInput(
+ loc, rewriter, unitSrc, /*size=*/1, unitVnni, unitMnDim);
+
+ auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+ /*size=*/nonUnitSize,
+ nonUnitVnni, nonUnitMnDim);
+
+ // Check failures once.
+ if (failed(unitSubview) || failed(nonUnitSubview))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The input source is not MemRef Type.");
+
+ llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview = *unitSubview;
+ llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview =
+ *nonUnitSubview;
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc,
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index a934188bb058d..f7e71a32766ab 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -66,7 +66,7 @@ bool mlir::x86vector::isInVnniLayout(Operation *op,
// The input matrix dimensions layout must match the following:
// - matrix A - [...][K/vnniFactor][vnniFactor]
// - matrix B - [...][K/vnniFactor][N][vnniFactor]
- auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+ auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
if (failed(maybeIters))
return false;
SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
>From e05648e0a4ccb0a3291312b58eb4a0f33dc27a1e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 10 Dec 2025 06:24:38 -0800
Subject: [PATCH 5/7] code refactor: new comments, code simplification etc.
---
.../Dialect/X86Vector/Utils/X86VectorUtils.h | 11 --
.../Transforms/VectorContractBF16ToFMA.cpp | 139 +++++++++---------
.../Dialect/X86Vector/Utils/CMakeLists.txt | 1 +
.../X86Vector/Utils/X86VectorUtils.cpp | 7 +-
4 files changed, 73 insertions(+), 85 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 8a76009ddb907..2de9a3122cbd9 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -16,21 +16,10 @@
#include <string>
namespace mlir {
-class Type;
-class ShapedType;
-class OpOperand;
-class AffineDimExpr;
class AffineMap;
-class VectorType;
class Operation;
namespace x86vector {
-enum class VnniOperandRank {
- TRANSPOSE = 3,
- GEMM = 3,
- BRGEMM_INS = 4,
- BRGEMM_OUTS = 3
-};
// Return true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 375364a4f1c7b..0d3576c4d5e57 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -1,5 +1,4 @@
-//===- VectorContractBF16ToFMA.cpp
-//--------------------------------------------===//
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -29,83 +28,82 @@ using namespace mlir::x86vector;
// reads and creates subviews for the BF16 packed-operations to
// broadcast or load BF16 elements as F32 packed elements.
//
-// For example:
+// Example(1) Unit Dim:
// ```
// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
-// vector.load %arg0[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
// ```
// to
// ```
// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
-// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
// ```
-static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
-getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
- mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
- int64_t mnDimIndx) {
-
- llvm::SmallVector<mlir::memref::SubViewOp> subviews;
-
- Value srcOperation;
- SmallVector<OpFoldResult> indexVals;
+//
+// Example(2) Non-unit Dim:
+// ```
+// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+ int64_t mnDimSize, int64_t vnniDimSize,
+ int64_t mnDimIdx) {
Operation *defOp = prodOp.getDefiningOp();
if (!defOp)
return failure();
- llvm::TypeSwitch<Operation *>(defOp)
- .Case<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(
- [&](auto readOp) {
- srcOperation = readOp.getOperand(0);
- indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
- readOp.getIndices().end());
- });
+ Value srcOperation;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcOperation = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ });
if (!srcOperation)
return failure();
Type srcType = srcOperation.getType();
- if (!llvm::isa<mlir::MemRefType>(srcType))
+ if (!llvm::isa<MemRefType>(srcType))
return failure();
- llvm::SmallVector<OpFoldResult> strides;
- llvm::SmallVector<OpFoldResult> sizes;
-
auto nonVNNIDimSize = indexVals.size() - 1;
// Create the size and stride offsets.
- for (unsigned int i = 0; i < nonVNNIDimSize; i++) {
- strides.push_back(rewriter.getIndexAttr(1));
- sizes.push_back(rewriter.getIndexAttr(1));
- }
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+ SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
strides.push_back(rewriter.getIndexAttr(1));
- sizes.push_back(rewriter.getIndexAttr(vnniDim));
+ sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
- // update the unit/nonUnit Dim size eiither it is A(LHS) or B(RHS).
- sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+ // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
+ sizes[indexVals.size() - mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
- // for unitDim, first broadcast odd element, so index is set to C1.
- if (mnDim == 1) {
+ // for unitDim, first broadcast odd element, so index is set to 1.
+ if (mnDimSize == 1)
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
- }
+ llvm::SmallVector<memref::SubViewOp> subviews;
auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
indexVals, sizes, strides);
subviews.push_back(subview);
// For unit-dims, two subviews should be created for the odd and even
- // indexed BF16 element because x86vector.avx.bcst_to_f32.packed op
- // loads and broadcast the first BF16 element into packed F32. It
+ // element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed
+ // op loads and broadcast the first BF16 element into packed F32. It
// cannot distinguish between even and odd BF16 elements within a
// packed pair.
- if (mnDim == 1) {
+ if (mnDimSize == 1) {
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
- auto unitDimEvenIndxSubview = memref::SubViewOp::create(
+ auto unitDimEvenIdxSubview = memref::SubViewOp::create(
rewriter, loc, srcOperation, indexVals, sizes, strides);
- subviews.push_back(unitDimEvenIndxSubview);
+ subviews.push_back(unitDimEvenIdxSubview);
}
return subviews;
@@ -140,7 +138,7 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(contractOp,
"Expects add combining kind.");
- // TODO: Move this validation to a comon utility folder. Planned to
+ // TODO: Move this validation to a common utility folder. Planned to
// do once (code refactoring), all architecture specific nanokernel
// passes are merged into the repo.
VectorType lhsTy = contractOp.getLhsType();
@@ -149,10 +147,19 @@ struct VectorContractBF16ToFMA
"Only BF16 lowering is supported.");
if (!isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(), 2))
+ contractOp.getIndexingMapsArray(),
+ /*blockingFactor=*/2))
return rewriter.notifyMatchFailure(contractOp,
"Input matrices not in VNNI format.");
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only F32 acumulation supported for BF16 type.");
+
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimLhs;
llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
@@ -174,13 +181,13 @@ struct VectorContractBF16ToFMA
contractOp,
"Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
- VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ /* VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
return rewriter.notifyMatchFailure(
- contractOp, "Only F32 acumulation supported for BF16 type.");
+ contractOp, "Only F32 acumulation supported for BF16 type."); */
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimAcc;
@@ -216,8 +223,7 @@ struct VectorContractBF16ToFMA
rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
// mnDim index differs depending on the orientation.
- int unitMnDim = rhsHasMultipleNonUnitDims ? 2 : 2; // same for both
- int nonUnitMnDim = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+ int mnDimIdx = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
// VNNI factor: always 1 for unit dims, 2 for non-unit dims.
int unitVnni = 1;
@@ -228,55 +234,52 @@ struct VectorContractBF16ToFMA
// Build subviews.
auto unitSubview = getSubviewFromVectorInput(
- loc, rewriter, unitSrc, /*size=*/1, unitVnni, unitMnDim);
+ loc, rewriter, unitSrc, /*size=*/1, unitVnni, mnDimIdx);
- auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
- /*size=*/nonUnitSize,
- nonUnitVnni, nonUnitMnDim);
+ auto nonUnitSubview =
+ getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+ /*size=*/nonUnitSize, nonUnitVnni, mnDimIdx);
// Check failures once.
if (failed(unitSubview) || failed(nonUnitSubview))
return rewriter.notifyMatchFailure(
contractOp, "The input source is not MemRef Type.");
- llvm::SmallVector<mlir::memref::SubViewOp> unitDimSubview = *unitSubview;
- llvm::SmallVector<mlir::memref::SubViewOp> nonUnitDimSubview =
- *nonUnitSubview;
+ SmallVector<memref::SubViewOp> unitDimSubview = *unitSubview;
+ SmallVector<memref::SubViewOp> nonUnitDimSubview = *nonUnitSubview;
auto castAcc = vector::ShapeCastOp::create(
rewriter, loc,
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
- mlir::VectorType dstType =
- mlir::VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+ VectorType dstType =
+ VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
// Load, broadcast, and do FMA for odd indexed BF16 elements.
- auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+ auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
- auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+ auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
- auto oddIndxFMA =
- vector::FMAOp::create(rewriter, loc, loadBcstOddIndxElementToF32,
- loadOddIndxElementF32, castAcc);
+ auto oddIdxFMA =
+ vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
+ loadOddIdxElementF32, castAcc);
llvm::SmallVector<Operation *> users;
for (OpResult result : contractOp->getResults())
for (Operation *user : result.getUsers())
users.push_back(user);
- if (users.size() == 1) {
+ if (users.size() == 1)
rewriter.setInsertionPoint(users[0]);
- }
// Load, broadcast, and do FMA for even indexed BF16 elements.
- auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+ auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]);
- auto loadEvenIndxElementF32 =
- x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
- nonUnitDimSubview[0]);
+ auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
+ rewriter, loc, dstType, nonUnitDimSubview[0]);
vector::FMAOp fma =
- vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
- loadEvenIndxElementF32, oddIndxFMA);
+ vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
+ loadEvenIdxElementF32, oddIdxFMA);
auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
rewriter.replaceOp(contractOp, castFma);
diff --git a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
index 6a2da861737ed..595846489f6c9 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt
@@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRX86VectorUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRIR
+ MLIRLinalgDialect
MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index f7e71a32766ab..bb31686a14616 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -5,10 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-//
-// This file implements utility methods for working with the Vector dialect.
-//
-//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
@@ -19,9 +15,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
-#define DEBUG_TYPE "x86vector-utils"
-
using namespace mlir;
+using namespace mlir::x86vector;
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
>From 18e77daf35703c0ca9606186bdab714d4571eeb0 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 11 Dec 2025 07:53:07 -0800
Subject: [PATCH 6/7] Wrap X86VectorUtils in the x86vector namespace, replace
 hard-coded dimension indices, and fix typos/remove dead code
---
.../Transforms/VectorContractBF16ToFMA.cpp | 51 ++++++++++---------
.../X86Vector/Utils/X86VectorUtils.cpp | 12 +++--
.../vector-contract-bf16-to-fma.mlir | 46 +++++++++++++++++
3 files changed, 81 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index 0d3576c4d5e57..b9ccee9bd7186 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -24,7 +24,7 @@ using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
-// This function retrives the source operation of the load or transfer
+// This function retrieves the source operation of the load or transfer
// reads and creates subviews for the BF16 packed-operations to
// broadcast or load BF16 elements as F32 packed elements.
//
@@ -81,7 +81,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
// update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
- sizes[indexVals.size() - mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+ sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
// for unitDim, first broadcast odd element, so index is set to 1.
if (mnDimSize == 1)
@@ -97,6 +97,9 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// op loads and broadcast the first BF16 element into packed F32. It
// cannot distinguish between even and odd BF16 elements within a
// packed pair.
+ // Example:
+ // memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
+ // memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
if (mnDimSize == 1) {
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
@@ -156,20 +159,32 @@ struct VectorContractBF16ToFMA
if (!accTy)
return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
- if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
+ if (!accTy.getElementType().isF32())
return rewriter.notifyMatchFailure(
contractOp, "Only F32 acumulation supported for BF16 type.");
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimLhs;
- llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
- [](int64_t dim) { return dim != 1; });
+ llvm::SmallVector<unsigned> nonUnitIndicesLhs;
+
+ for (auto it : llvm::enumerate(lhsShape)) {
+ if (it.value() != 1) {
+ nonUnitDimLhs.push_back(it.value());
+ nonUnitIndicesLhs.push_back(it.index());
+ }
+ }
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimRhs;
- llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
- [](int64_t dim) { return dim != 1; });
+ llvm::SmallVector<unsigned> nonUnitIndicesRhs;
+
+ for (auto it : llvm::enumerate(rhsShape)) {
+ if (it.value() != 1) {
+ nonUnitDimRhs.push_back(it.value());
+ nonUnitIndicesRhs.push_back(it.index());
+ }
+ }
if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(contractOp,
@@ -181,14 +196,6 @@ struct VectorContractBF16ToFMA
contractOp,
"Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
- /* VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
- if (!accTy)
- return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
-
- if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
- return rewriter.notifyMatchFailure(
- contractOp, "Only F32 acumulation supported for BF16 type."); */
-
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimAcc;
llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
@@ -223,7 +230,9 @@ struct VectorContractBF16ToFMA
rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
// mnDim index differs depending on the orientation.
- int mnDimIdx = rhsHasMultipleNonUnitDims ? 2 : 3; // A or B
+ int mnDimIdx = rhsHasMultipleNonUnitDims
+ ? nonUnitIndicesRhs.front()
+ : nonUnitIndicesLhs.front(); // A or B
// VNNI factor: always 1 for unit dims, 2 for non-unit dims.
int unitVnni = 1;
@@ -264,13 +273,9 @@ struct VectorContractBF16ToFMA
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
loadOddIdxElementF32, castAcc);
- llvm::SmallVector<Operation *> users;
- for (OpResult result : contractOp->getResults())
- for (Operation *user : result.getUsers())
- users.push_back(user);
-
- if (users.size() == 1)
- rewriter.setInsertionPoint(users[0]);
+ OpResult vcResult = contractOp->getResult(0);
+ if (vcResult.hasOneUse())
+ rewriter.setInsertionPoint(*vcResult.getUsers().begin());
// Load, broadcast, and do FMA for even indexed BF16 elements.
auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index bb31686a14616..ccb2e92fdd9e2 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -15,8 +15,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
-using namespace mlir;
-using namespace mlir::x86vector;
+namespace mlir {
+namespace x86vector {
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
@@ -32,9 +32,8 @@ inferIteratorsFromOutMap(AffineMap map) {
// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
-bool mlir::x86vector::isInVnniLayout(Operation *op,
- ArrayRef<AffineMap> indexingMaps,
- std::optional<unsigned> blockingFactor) {
+bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor) {
// Narrow down type operations - VNNI only applies to contractions.
FailureOr<linalg::ContractionDimensions> dims =
linalg::inferContractionDims(indexingMaps);
@@ -104,3 +103,6 @@ bool mlir::x86vector::isInVnniLayout(Operation *op,
return true;
}
+
+} // namespace x86vector
+} // namespace mlir
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index 0b92e9367365f..aa7cfc84b0f7b 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -230,6 +230,52 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x1x1x1x2xbf16>
+!vecB = vector<1x1x1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x1x4x1x2xbf16>
+!memrefB = memref<1x1x1x32x2xbf16>
+#map = affine_map<(d5, d0, d4, d1, d2, d3) -> (d5, d0, d1, d3, d4)>
+#map1 = affine_map<(d5, d0, d4, d1, d2, d3) -> (d5, d0, d3, d2, d4)>
+#map2 = affine_map<(d5, d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @many_dimensions(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @many_dimensions
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
>From 406b66c077ae724483ef81f68cb13b3da692c981 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 12 Dec 2025 01:19:09 -0800
Subject: [PATCH 7/7] Derive subview sizes from the operand shape and add new test cases
---
.../Transforms/VectorContractBF16ToFMA.cpp | 80 ++++++++---------
.../vector-contract-bf16-to-fma.mlir | 85 +++++++++++++++----
2 files changed, 103 insertions(+), 62 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index b9ccee9bd7186..cae8cfbf7c00f 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -48,48 +48,61 @@ using namespace mlir::x86vector;
// ```
static FailureOr<SmallVector<memref::SubViewOp>>
getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
- int64_t mnDimSize, int64_t vnniDimSize,
- int64_t mnDimIdx) {
+ ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
Operation *defOp = prodOp.getDefiningOp();
if (!defOp)
return failure();
- Value srcOperation;
+ Value srcBuff;
SmallVector<OpFoldResult> indexVals;
llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
[&](auto readOp) {
- srcOperation = readOp.getOperand(0);
+ srcBuff = readOp.getOperand(0);
indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
readOp.getIndices().end());
});
- if (!srcOperation)
+ if (!srcBuff)
return failure();
- Type srcType = srcOperation.getType();
+ Type srcType = srcBuff.getType();
if (!llvm::isa<MemRefType>(srcType))
return failure();
+ int64_t mnDimSize = 1;
+ unsigned mnDimIdx = 0;
+
+ if (!isUnitDim) {
+ for (auto it : llvm::enumerate(nonUnitDimShape)) {
+ if (it.value() != 1) {
+ mnDimSize = it.value();
+ mnDimIdx = it.index();
+ break;
+ }
+ }
+ }
+
+ int vnniDimSize = isUnitDim ? 1 : 2;
+
auto nonVNNIDimSize = indexVals.size() - 1;
// Create the size and stride offsets.
auto one = rewriter.getIndexAttr(1);
- SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+ SmallVector<OpFoldResult> strides(indexVals.size(), one);
SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
- strides.push_back(rewriter.getIndexAttr(1));
sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
// update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
// for unitDim, first broadcast odd element, so index is set to 1.
- if (mnDimSize == 1)
+ if (isUnitDim)
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
llvm::SmallVector<memref::SubViewOp> subviews;
- auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
- indexVals, sizes, strides);
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+ sizes, strides);
subviews.push_back(subview);
// For unit-dims, two subviews should be created for the odd and even
@@ -97,6 +110,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// op loads and broadcast the first BF16 element into packed F32. It
// cannot distinguish between even and odd BF16 elements within a
// packed pair.
+ //
// Example:
// memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
// memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
@@ -105,7 +119,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
auto unitDimEvenIdxSubview = memref::SubViewOp::create(
- rewriter, loc, srcOperation, indexVals, sizes, strides);
+ rewriter, loc, srcBuff, indexVals, sizes, strides);
subviews.push_back(unitDimEvenIdxSubview);
}
@@ -165,26 +179,14 @@ struct VectorContractBF16ToFMA
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimLhs;
- llvm::SmallVector<unsigned> nonUnitIndicesLhs;
-
- for (auto it : llvm::enumerate(lhsShape)) {
- if (it.value() != 1) {
- nonUnitDimLhs.push_back(it.value());
- nonUnitIndicesLhs.push_back(it.index());
- }
- }
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
llvm::SmallVector<int64_t> nonUnitDimRhs;
- llvm::SmallVector<unsigned> nonUnitIndicesRhs;
-
- for (auto it : llvm::enumerate(rhsShape)) {
- if (it.value() != 1) {
- nonUnitDimRhs.push_back(it.value());
- nonUnitIndicesRhs.push_back(it.index());
- }
- }
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
return rewriter.notifyMatchFailure(contractOp,
@@ -229,25 +231,15 @@ struct VectorContractBF16ToFMA
Value nonUnitSrc =
rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
- // mnDim index differs depending on the orientation.
- int mnDimIdx = rhsHasMultipleNonUnitDims
- ? nonUnitIndicesRhs.front()
- : nonUnitIndicesLhs.front(); // A or B
-
- // VNNI factor: always 1 for unit dims, 2 for non-unit dims.
- int unitVnni = 1;
- int nonUnitVnni = 2;
-
- // Non-unit dim size.
- int nonUnitSize = nonUnitDimRhs.front();
+ ArrayRef<int64_t> nonUnitDimShape =
+ rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
// Build subviews.
- auto unitSubview = getSubviewFromVectorInput(
- loc, rewriter, unitSrc, /*size=*/1, unitVnni, mnDimIdx);
+ auto unitSubview = getSubviewFromVectorInput(loc, rewriter, unitSrc,
+ nonUnitDimShape, true);
- auto nonUnitSubview =
- getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
- /*size=*/nonUnitSize, nonUnitVnni, mnDimIdx);
+ auto nonUnitSubview = getSubviewFromVectorInput(loc, rewriter, nonUnitSrc,
+ nonUnitDimShape, false);
// Check failures once.
if (failed(unitSubview) || failed(nonUnitSubview))
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
index aa7cfc84b0f7b..c55a859340600 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir
@@ -230,6 +230,52 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<8x1x2xbf16>
+!vecB = vector<1x1x2xbf16>
+!vecC = vector<8x1xf32>
+!memrefA = memref<32x1x2xbf16>
+!memrefB = memref<1x4x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @matmul_to_fma_load_bcst_B(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+ %c0 = arith.constant 0 : index
+ %0 = ub.poison : bf16
+ %1 = vector.load %arg0[%c0, %c0, %c0] :
+ !memrefA, !vecA
+ %2 = vector.load %arg1[%c0, %c0, %c0] :
+ !memrefB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
+}
+
+// CHECK-LABEL: @matmul_to_fma_load_bcst_B
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x1x2xbf16>
!vecB = vector<1x1x1x8x2xbf16>
!vecC = vector<1x8xf32>
@@ -276,26 +322,29 @@ module attributes {transform.with_named_sequence} {
// -----
-#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
-
-func.func @negative_tensor_type(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!tensorA = tensor<4x1x2xbf16>
+!tensorB = tensor<1x32x2xbf16>
+#map = affine_map<(d1, d2, d3, d4) -> (d2, d4, d1)>
+#map1 = affine_map<(d1, d2, d3, d4) -> (d4, d3, d1)>
+#map2 = affine_map<(d1, d2, d3, d4) -> (d2, d3)>
+func.func @negative_tensor_type(%arg0: !tensorA, %arg1: !tensorB, %arg2: !vecC) -> !vecC {
%0 = ub.poison : bf16
%c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c16 = arith.constant 16 : index
- %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
- %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
- %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
- %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
- %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
- %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
- %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
- %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
- %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
- %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
- return %8 : vector<1x16xf32>
+ %c8 = arith.constant 8 : index
+ %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+ !tensorA, !vecA
+ %2 = vector.transfer_read %arg1[%c0, %c8, %c0], %0 {in_bounds = [true, true, true]} :
+ !tensorB, !vecB
+ %3 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %2, %arg2
+ : !vecA, !vecB into !vecC
+ return %3 : !vecC
}
// CHECK-LABEL: @negative_tensor_type
More information about the Mlir-commits
mailing list