[Mlir-commits] [mlir] [mlir][vector] Relax the requirements on broadcast dims (PR #99341)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Oct 3 03:13:21 PDT 2024


================
@@ -4138,22 +4134,42 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
   bool changed = false;
   SmallVector<bool, 4> newInBounds;
   newInBounds.reserve(op.getTransferRank());
+  SmallVector<unsigned> nonBcastDims;
   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
-    // Already marked as in-bounds, nothing to see here.
+    // 1. Already marked as in-bounds, nothing to see here.
     if (op.isDimInBounds(i)) {
       newInBounds.push_back(true);
       continue;
     }
-    // Currently out-of-bounds, check whether we can statically determine it is
-    // inBounds.
+    // 2. Currently out-of-bounds, check whether we can statically determine it
+    // is inBounds.
+    bool inBounds = false;
     auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
-    assert(dimExpr && "Broadcast dims must be in-bounds");
-    auto inBounds =
-        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
+    if (dimExpr) {
+      // 2.a Non-broadcast dim
+      inBounds = isInBounds(op, /*resultIdx=*/i,
+                            /*indicesIdx=*/dimExpr.getPosition());
+      // 2.b Broadcast dims are handled after processing non-bcast dims
+      // FIXME: constant expr != 0 are not broadcasts - should such
+      // constants be allowed at all?
+      nonBcastDims.push_back(i);
+    }
+
     newInBounds.push_back(inBounds);
     // We commit the pattern if it is "more inbounds".
     changed |= inBounds;
   }
+
+  // Handle broadcast dims: if all non-broadcast dims are "in
+  // bounds", then all bcast dims should be "in bounds" as well.
+  bool allNonBcastDimsInBounds = llvm::all_of(
+      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
+  if (allNonBcastDimsInBounds)
+    llvm::for_each(permutationMap.getBroadcastDims(), [&](unsigned idx) {
+      changed |= !newInBounds[idx];
+      newInBounds[idx] = true;
----------------
banach-space wrote:

Updated.

https://github.com/llvm/llvm-project/pull/99341


More information about the Mlir-commits mailing list