[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Dec 11 09:24:17 PST 2025


================
@@ -0,0 +1,298 @@
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// This function retrieves the source operation of the load or transfer
+// reads and creates subviews for the BF16 packed-operations to
+// broadcast or load BF16 elements as F32 packed elements.
+//
+// Example(1) Unit Dim:
+// ```
+//   vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+//   memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+//
+// Example(2) Non-unit Dim:
+// ```
+//   vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+                          int64_t mnDimSize, int64_t vnniDimSize,
+                          int64_t mnDimIdx) {
+
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
----------------
adam-smnk wrote:

We need to be careful when replacing `vector.transfer_read`, as it's a pretty semantically rich op.
I don't think this rewrite is valid when there are out-of-bounds accesses that need padding applied.

Please double-check the potential transfer edge cases (e.g., are arbitrary permutation maps fine for us? can masking be supported?) and add more tests covering them.
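
To make the edge cases concrete, here is a rough sketch of the kinds of reads I have in mind. The function name, shapes, and indices are made up for illustration and are not taken from the PR's tests; only case (a) looks obviously equivalent to a plain `vector.load`, and the others are the ones I would expect the pattern to either handle explicitly or reject.

```mlir
// Hypothetical inputs; names and shapes are assumptions for illustration.
func.func @transfer_read_edge_cases(%arg0: memref<4x1x2xbf16>,
                                    %mask: vector<1x1x2xi1>) {
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %pad = arith.constant 0.0 : bf16

  // (a) Fully in bounds, default minor-identity map, no mask.
  %a = vector.transfer_read %arg0[%c0, %c0, %c0], %pad
         {in_bounds = [true, true, true]}
         : memref<4x1x2xbf16>, vector<1x1x2xbf16>

  // (b) Outermost dim may run past the memref (3 + 2 > 4); the
  //     out-of-bounds elements take the value of %pad.
  %b = vector.transfer_read %arg0[%c3, %c0, %c0], %pad
         {in_bounds = [false, true, true]}
         : memref<4x1x2xbf16>, vector<2x1x2xbf16>

  // (c) Masked read; masked-off lanes take the value of %pad.
  %c = vector.transfer_read %arg0[%c0, %c0, %c0], %pad, %mask
         {in_bounds = [true, true, true]}
         : memref<4x1x2xbf16>, vector<1x1x2xbf16>

  // (d) Non-identity permutation map: the two innermost dims are swapped.
  %d = vector.transfer_read %arg0[%c0, %c0, %c0], %pad
         {in_bounds = [true, true, true],
          permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>}
         : memref<4x1x2xbf16>, vector<1x2x1xbf16>
  return
}
```

A plain `memref.subview` of the source cannot express the padding in (b) and (c) or the transposition in (d), so I would expect the rewrite to bail out on those unless they are handled separately.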

https://github.com/llvm/llvm-project/pull/170267

