[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Dec 11 09:24:18 PST 2025
================
@@ -0,0 +1,298 @@
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// This function retrieves the source memref of a vector.load or
+// vector.transfer_read and creates subviews from which the BF16 packed
+// operations broadcast or load BF16 elements as packed F32 elements.
+//
+// Example(1) Unit Dim:
+// ```
+// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+//
+// Example(2) Non-unit Dim:
+// ```
+// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+ int64_t mnDimSize, int64_t vnniDimSize,
+ int64_t mnDimIdx) {
+
+ Operation *defOp = prodOp.getDefiningOp();
+ if (!defOp)
+ return failure();
+
+ Value srcOperation;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcOperation = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ });
+
+ if (!srcOperation)
+ return failure();
+
+ Type srcType = srcOperation.getType();
+ if (!llvm::isa<MemRefType>(srcType))
+ return failure();
+
+ auto nonVNNIDimSize = indexVals.size() - 1;
+ // Create the size and stride offsets.
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+ SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
+
+ strides.push_back(rewriter.getIndexAttr(1));
+ sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
+
+ // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
+ sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+
+ // for unitDim, first broadcast odd element, so index is set to 1.
+ if (mnDimSize == 1)
+ indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
----------------
adam-smnk wrote:
There's no guarantee that the offsets are zero, so they can't be directly overwritten.
You probably need to add the unit offset to the original value.
Also, consider making the offsets dynamic (e.g., passing a function argument as the offset value) in at least one or two test cases for increased coverage.
https://github.com/llvm/llvm-project/pull/170267
More information about the Mlir-commits
mailing list