[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

Arun Thangamani llvmlistbot at llvm.org
Tue Dec 16 04:02:38 PST 2025


================
@@ -0,0 +1,298 @@
+//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// This function retrieves the source operation of the load or transfer
+// reads and creates subviews for the BF16 packed-operations to
+// broadcast or load BF16 elements as F32 packed elements.
+//
+// Example(1) Unit Dim:
+// ```
+//   vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+//   memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+// ```
+//
+// Example(2) Non-unit Dim:
+// ```
+//   vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<SmallVector<memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
+                          int64_t mnDimSize, int64_t vnniDimSize,
+                          int64_t mnDimIdx) {
+
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return failure();
+
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        srcOperation = readOp.getOperand(0);
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+      });
+
+  if (!srcOperation)
+    return failure();
+
+  Type srcType = srcOperation.getType();
+  if (!llvm::isa<MemRefType>(srcType))
+    return failure();
+
+  auto nonVNNIDimSize = indexVals.size() - 1;
+  // Create the size and stride offsets.
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(nonVNNIDimSize, one);
+  SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
+
+  strides.push_back(rewriter.getIndexAttr(1));
+  sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
+
+  // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
+  sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+
+  // for unitDim, first broadcast odd element, so index is set to 1.
+  if (mnDimSize == 1)
+    indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
----------------
arun-thmn wrote:

Added a check to validate that the `vnni` offset is `zero`. Added dynamic-offset-related test cases (both positive and negative).

https://github.com/llvm/llvm-project/pull/170267


More information about the Mlir-commits mailing list