[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

Arun Thangamani llvmlistbot at llvm.org
Wed Dec 17 05:25:28 PST 2025


================
@@ -24,6 +26,63 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+static bool validateVectorProdOp(Value prodOp) {
+  Operation *defOp = prodOp.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // If the LHS/RHS op is transfer_read return false if:
+  // (1) - It has false in-bounds
+  // (2) - The permutation map is not identical
+  if (auto readOp = prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
+    ArrayAttr inBoundsAttr = readOp.getInBoundsAttr();
+    if (inBoundsAttr) {
+
+      for (Attribute attr : inBoundsAttr) {
+        auto boolAttr = llvm::dyn_cast<BoolAttr>(attr);
+        if (!boolAttr || !boolAttr.getValue()) {
+          return false;
+        }
+      }
+    }
+
+    if (!readOp.getPermutationMap().isIdentity())
+      return false;
+  }
+
+  Value srcBuff;
+  SmallVector<OpFoldResult> indexVals;
+  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        srcBuff = readOp.getOperand(0);
+        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+      });
+
+  if (!srcBuff)
+    return false;
+
+  // Return false, if the source is not a memref type
+  Type srcType = srcBuff.getType();
+  if (!llvm::isa<MemRefType>(srcType))
+    return false;
+
+  // Return false, if the innermost stride of the memref is not 1.
----------------
arun-thmn wrote:

Updated the comments to explain why the last two dimensions must be contiguous.

https://github.com/llvm/llvm-project/pull/170267


More information about the Mlir-commits mailing list