[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Dec 11 09:24:17 PST 2025
================
@@ -0,0 +1,290 @@
+//===- VectorContractBF16ToFMA.cpp
+//--------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// This function retrieves the source operation of the load or transfer
+// reads and creates subviews for the BF16 packed-operations to
+// broadcast or load BF16 elements as F32 packed elements.
+//
+// For example:
+// ```
+// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
+ mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
+ int64_t mnDimIndx) {
+
+ llvm::SmallVector<mlir::memref::SubViewOp> subviews;
+
+ Value srcOperation;
+ SmallVector<OpFoldResult> indexVals;
+
+ Operation *defOp = prodOp.getDefiningOp();
+ if (!defOp)
+ return failure();
+
+ llvm::TypeSwitch<Operation *>(defOp)
+ .Case<mlir::vector::TransferReadOp, mlir::vector::LoadOp>(
+ [&](auto readOp) {
+ srcOperation = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ });
+
+ if (!srcOperation)
+ return failure();
+
+ Type srcType = srcOperation.getType();
+ if (!llvm::isa<mlir::MemRefType>(srcType))
+ return failure();
+
+ llvm::SmallVector<OpFoldResult> strides;
+ llvm::SmallVector<OpFoldResult> sizes;
+
+ auto nonVNNIDimSize = indexVals.size() - 1;
+ // Create the size and stride offsets.
+ for (unsigned int i = 0; i < nonVNNIDimSize; i++) {
+ strides.push_back(rewriter.getIndexAttr(1));
+ sizes.push_back(rewriter.getIndexAttr(1));
+ }
+
+ strides.push_back(rewriter.getIndexAttr(1));
+ sizes.push_back(rewriter.getIndexAttr(vnniDim));
+
+ // Update the unit/non-unit dim size, whether it is A (LHS) or B (RHS).
+ sizes[indexVals.size() - mnDimIndx] = rewriter.getIndexAttr(mnDim);
+
+ // for unitDim, first broadcast odd element, so index is set to C1.
+ if (mnDim == 1) {
+ indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
+ }
+
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcOperation,
+ indexVals, sizes, strides);
+ subviews.push_back(subview);
+
+ // For unit-dims, two subviews should be created for the odd and even
----------------
adam-smnk wrote:
I feel like `mnDimSize` and `vnniDimSize` are tightly coupled.
Do they really need to be free parameters? Perhaps they should be selected by the function itself based on some decision flag?
I'm trying to figure out some representation that would make this function more self-contained.
https://github.com/llvm/llvm-project/pull/170267
More information about the Mlir-commits
mailing list