[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Dec 17 02:31:55 PST 2025
================
@@ -24,6 +26,63 @@ using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
+static bool validateVectorProdOp(Value prodOp) {
+ Operation *defOp = prodOp.getDefiningOp();
+ if (!defOp)
+ return false;
+
+ // If the LHS/RHS op is transfer_read return false if:
+ // (1) - It has false in-bounds
+ // (2) - The permutation map is not identical
+ if (auto readOp = prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
+ ArrayAttr inBoundsAttr = readOp.getInBoundsAttr();
+ if (inBoundsAttr) {
+
+ for (Attribute attr : inBoundsAttr) {
+ auto boolAttr = llvm::dyn_cast<BoolAttr>(attr);
+ if (!boolAttr || !boolAttr.getValue()) {
+ return false;
+ }
+ }
+ }
+
+ if (!readOp.getPermutationMap().isIdentity())
+ return false;
+ }
+
+ Value srcBuff;
+ SmallVector<OpFoldResult> indexVals;
+ llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcBuff = readOp.getOperand(0);
+ indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+ readOp.getIndices().end());
+ });
+
+ if (!srcBuff)
+ return false;
+
+ // Return false, if the source is not a memref type
+ Type srcType = srcBuff.getType();
+ if (!llvm::isa<MemRefType>(srcType))
+ return false;
+
+ // Return false, if the innermost stride of the memref is not 1.
+ auto [strides, offset] =
+ llvm::cast<mlir::MemRefType>(srcType).getStridesAndOffset();
+ if (!strides.empty()) {
+ int64_t s = strides.back();
+ if (s != mlir::ShapedType::kDynamic && s != 1)
+ return false;
+ }
----------------
adam-smnk wrote:
nit: you could use `areTrailingDimsContiguous` instead. See `MemRefType` API.
https://github.com/llvm/llvm-project/pull/170267
More information about the Mlir-commits mailing list