[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Dec 11 02:18:51 PST 2025
================
@@ -0,0 +1,293 @@
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Retrieves the memref source of the vector.load / vector.transfer_read that
+// produces `prodOp` and creates memref.subview ops over it, so that the BF16
+// packed operations can broadcast or load BF16 elements as packed F32
+// elements.
+//
+// Parameters:
+//   prodOp      - value that must be defined by a vector.transfer_read or
+//                 vector.load whose source is a memref; otherwise failure()
+//                 is returned.
+//   mnDimSize   - size of the M/N dimension of A (LHS) or B (RHS).
+//   vnniDimSize - size of the innermost (VNNI) dimension.
+//   mnDimIdx    - position of the M/N dimension counted from the back, i.e.
+//                 dimension `rank - mnDimIdx`. It must lie in [1, rank];
+//                 otherwise failure() is returned.
+//
+// Returns:
+//   - Non-unit M/N dim: a single subview at the read's indices, with
+//     `mnDimSize` elements along the M/N dim, `vnniDimSize` along the VNNI
+//     dim, and 1 everywhere else.
+//   - Unit M/N dim: two single-element subviews, the odd VNNI element
+//     (offset 1) first and the even one (offset 0) second.
+//
+// Example(1) Unit Dim:
+// ```
+// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+//
+// Example(2) Non-unit Dim:
+// ```
+// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+                          int64_t mnDimSize, int64_t vnniDimSize,
+                          int64_t mnDimIdx) {
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        // NOTE(review): a transfer_read's mask and permutation map are not
+        // inspected here; confirm callers only feed unmasked identity reads.
+        srcOperation = readOp.getOperand(0);
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+      });
+
+  if (!srcOperation || !llvm::isa<MemRefType>(srcOperation.getType()))
+    return failure();
+
+  // `mnDimIdx` counts from the back and must name an existing dimension.
+  // Rejecting anything else also covers an empty index list, which would
+  // otherwise make `rank - 1` underflow and index `sizes` out of bounds.
+  const size_t rank = indexVals.size();
+  if (mnDimIdx < 1 || static_cast<size_t>(mnDimIdx) > rank)
+    return failure();
+
+  IntegerAttr one = rewriter.getIndexAttr(1);
+  // Unit strides everywhere. Every non-VNNI dimension starts at size 1 and
+  // the innermost (VNNI) dimension spans `vnniDimSize` elements.
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> sizes(rank - 1, one);
+  sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
+
+  // Widen the M/N dimension of A (LHS) or B (RHS) to its full extent.
+  sizes[rank - mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+
+  // For a unit M/N dim, x86vector.avx.bcst_to_f32.packed loads and broadcasts
+  // only the first BF16 element it is given, so it cannot choose between the
+  // even and odd element of a VNNI pair. Instead, emit one single-element
+  // subview per element, odd (offset 1) first. The VNNI size must shrink to 1
+  // *before* the first subview: offset 1 with the full `vnniDimSize` extent
+  // would run past the VNNI dimension (e.g. 1 + 2 > 2 in Example 1).
+  const bool isUnitDim = mnDimSize == 1;
+  if (isUnitDim) {
+    sizes.back() = one;
+    indexVals.back() = one;
+  }
+
+  SmallVector<memref::SubViewOp> subviews;
+  subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                               indexVals, sizes, strides));
+
+  if (isUnitDim) {
+    // Even element of the VNNI pair.
+    indexVals.back() = rewriter.getIndexAttr(0);
+    subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                                 indexVals, sizes, strides));
+  }
+
+  return subviews;
+}
+
+// Lowers an outer-product vector.contract to a sequence of BF16-packed
+// odd/even element loads (converted to F32) and vector.fma operations.
+//
+// For example:
+// ```
+// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
+// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
+// return vector.contract %1, %2, %arg1
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+// return vector.fma %4, %5, %3
+// ```
+struct VectorContractBF16ToFMA
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ // TODO: Move this validation to a common utility folder. Planned to
+ // do once (code refactoring), all architecture specific nanokernel
+ // passes are merged into the repo.
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only BF16 lowering is supported.");
+
+ if (!isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ /*blockingFactor=*/2))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices not in VNNI format.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()))
----------------
adam-smnk wrote:
`lhs` type check is still duplicated.
https://github.com/llvm/llvm-project/pull/170267
More information about the Mlir-commits
mailing list