[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Arun Thangamani
llvmlistbot at llvm.org
Tue Dec 9 06:51:29 PST 2025
================
@@ -0,0 +1,274 @@
+//===- VectorContractBF16ToFMA.cpp
+//--------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// This function retrieves the source operand of a vector load or transfer
+// read and creates subviews of it so that the BF16 packed operations can
+// broadcast or load BF16 elements as packed F32 elements.
+//
+// For example:
+// ```
+// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
+ mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
+ int64_t mnDimIndx) {
+
+ llvm::SmallVector<mlir::memref::SubViewOp> subviews;
+
+ Value srcOperation;
+ SmallVector<OpFoldResult> indexVals;
+
+ if (auto transferRead =
+ prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
+ srcOperation = transferRead.getOperand(0);
+ SmallVector<OpFoldResult> indexValues(transferRead.getIndices().begin(),
+ transferRead.getIndices().end());
+ indexVals = indexValues;
+ }
+
+ if (auto load = prodOp.getDefiningOp<mlir::vector::LoadOp>()) {
+ srcOperation = load.getOperand(0);
+ SmallVector<OpFoldResult> indexValues(load.getIndices().begin(),
+ load.getIndices().end());
+ indexVals = indexValues;
+ }
+
+ if (!srcOperation)
+ return failure();
+
+ Type srcType = srcOperation.getType();
+ if (!llvm::isa<mlir::MemRefType>(srcType))
+ return failure();
+
+ llvm::SmallVector<OpFoldResult> strides;
+ llvm::SmallVector<OpFoldResult> sizes;
+
+ for (unsigned int i = 0; i < indexVals.size(); i++) {
+ strides.push_back(rewriter.getIndexAttr(1));
+ sizes.push_back(rewriter.getIndexAttr(1));
+ }
+
+ sizes[indexVals.size() - 1] = rewriter.getIndexAttr(vnniDim);
+ sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+
+ if (mnDim == 1) {
+ indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
+ }
+
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
+ indexVals, sizes, strides);
+ subviews.push_back(subview);
+
+ if (mnDim == 1) {
+ indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
+ sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
+
+ auto unitDimEvenIndxSubview = memref::SubViewOp::create(
----------------
arun-thmn wrote:
Added a comment explaining why we need two subviews for unitDim.
Reason: `For unit dims, two subviews should be created, one for the odd-indexed and one for the even-indexed BF16 element, because the x86vector.avx.bcst_to_f32.packed op loads and broadcasts the first BF16 element into packed F32. It cannot distinguish between even and odd BF16 elements within a packed pair.`
https://github.com/llvm/llvm-project/pull/170267
More information about the Mlir-commits
mailing list