[Mlir-commits] [mlir] [mlir][vector] Restrict DropInnerMostUnitDimsTransferWrite (PR #96218)

Andrzej WarzyƄski llvmlistbot at llvm.org
Thu Jul 11 03:46:19 PDT 2024


================
@@ -1394,6 +1394,33 @@ class DropInnerMostUnitDimsTransferWrite
     if (dimsToDrop == 0)
       return failure();
 
+    // We need to consider 3 cases for the dim to drop:
+    //  1. if "in bounds", it can safely be assumed that the corresponding
+    //     index is equal to 0 (safe to collapse) (*)
+    //  2. if "out of bounds" and the corresponding index is 0, it is
+    //     effectively "in bounds" (safe to collapse)
+    //  3. if "out of bounds" and the corresponding index is != 0,
+    //     be conservative and bail out (not safe to collapse)
+    // (*)  This pattern only drops unit dims, so the only possible "in bounds"
+    // index is "0". This could be added as a folder.
+    // TODO: Deal with 3. by e.g. propagating the "out of bounds" flag to other
+    // dims.
+    bool indexOutOfBounds = true;
+    if (writeOp.getInBounds())
+      indexOutOfBounds = llvm::any_of(
+          llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
+                    writeOp.getIndices().take_back(dimsToDrop)),
+          [](auto zipped) {
+            auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
+            auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
+            return !inBounds && nonZeroIdx;
+          });
+    else
+      indexOutOfBounds = !llvm::all_of(
+          writeOp.getIndices().take_back(dimsToDrop), isZeroIndex);
+    if (indexOutOfBounds)
+      return failure();
----------------
banach-space wrote:

>  dropping the dim if the index is zero and marked as out-of-bounds does not seem valid. If index zero is out-of-bounds, then we can't safely write to that unit dimension

As you observed further down:

> If index 0 and in-bounds = false for a unit-dim actually means in-bounds = true

:) So:
1. If the index is == 0, then `in_bounds` is effectively irrelevant (safe to collapse).
2. If the index is != 0 but `in_bounds = true`, the index is effectively == 0 (safe to collapse).
3. If the index is != 0 and `in_bounds = false`, bail out.
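The three cases above can be sketched as a per-dimension predicate. This is a minimal, hypothetical model rather than the pattern's actual code: `canCollapseUnitDim` is an illustrative name, and a dynamic (non-constant) index is modelled as `std::nullopt`, i.e. "not known to be zero".

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of the decision for one trailing dim of a
// vector.transfer_write. `constIndex` is std::nullopt for a dynamic index.
// Returns true if the dim can be dropped (collapsed) safely.
bool canCollapseUnitDim(int64_t vecDimSize, bool inBounds,
                        std::optional<int64_t> constIndex) {
  assert(vecDimSize == 1 && "the pattern only drops unit dims");
  bool isZero = constIndex.has_value() && *constIndex == 0;
  // Case 1: index == 0, so `in_bounds` is effectively irrelevant.
  if (isZero)
    return true;
  // Case 2: index != 0 (or unknown) but `in_bounds = true`. For a unit dim
  // the only in-bounds index is 0, so the index is effectively 0.
  if (inBounds)
    return true;
  // Case 3: index != 0 (or unknown) and `in_bounds = false`: bail out.
  return false;
}
```

Note that a dynamic index is only rejected in case 3; with `in_bounds = true` it is accepted, which matches `!inBounds && nonZeroIdx` in the diff.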

Note that once `in_bounds` is mandatory, I will be able to simplify the above as:
```cpp
      if (llvm::any_of(
              llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
                        writeOp.getIndices().take_back(dimsToDrop)),
              [](auto zipped) {
                auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
                auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
                return !inBounds && nonZeroIdx;
              }))
        return failure();
```
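For intuition, here is a hypothetical plain-C++ model of that simplified check with no MLIR dependency. `WriteModel` and `mustBailOut` are illustrative names, `std::nullopt` stands for a dynamic index, and the loop over the last `dimsToDrop` entries plays the role of `take_back(dimsToDrop)`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the transfer_write state the pattern inspects:
// one in_bounds flag and one index per dim (std::nullopt = dynamic index).
struct WriteModel {
  std::vector<bool> inBounds;                  // mandatory: one flag per dim
  std::vector<std::optional<int64_t>> indices; // same rank as inBounds
};

// Bail out if any of the trailing `dimsToDrop` dims is possibly out of
// bounds *and* its index is not a known 0 (case 3).
bool mustBailOut(const WriteModel &op, std::size_t dimsToDrop) {
  assert(op.inBounds.size() == op.indices.size() && "rank mismatch");
  assert(dimsToDrop <= op.indices.size() && "cannot drop more dims than rank");
  std::size_t first = op.indices.size() - dimsToDrop; // emulates take_back
  for (std::size_t i = first; i < op.indices.size(); ++i) {
    bool nonZeroIdx = !(op.indices[i] && *op.indices[i] == 0);
    if (!op.inBounds[i] && nonZeroIdx)
      return true; // not safe to collapse
  }
  return false;
}
```

Only the trailing dims are inspected; a non-zero, out-of-bounds index in a leading dim that is not being dropped does not affect the result.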

> out-of-bounds dims are an edge-case

Not sure we can say that: the default for `in_bounds` is "out of bounds" (when the attribute is omitted, every dimension is treated as potentially out of bounds).
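Because an omitted `in_bounds` means "out of bounds" for every dimension, the `else` branch in the diff is just the `any_of` check evaluated with all flags set to `false`. A minimal sketch under that assumption; all helper names are hypothetical and a dynamic index is modelled as `std::nullopt`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Index = std::optional<int64_t>; // std::nullopt models a dynamic index

static bool isKnownZero(const Index &idx) { return idx && *idx == 0; }

// Hypothetical helper: an absent in_bounds attribute means every dim is
// "possibly out of bounds", i.e. all flags default to false.
std::vector<bool> materializeInBounds(const std::optional<std::vector<bool>> &attr,
                                      std::size_t rank) {
  if (attr) {
    assert(attr->size() == rank && "one flag per dim");
    return *attr;
  }
  return std::vector<bool>(rank, false);
}

// The `any_of` branch, applied to the dims being dropped.
bool anyUnsafe(const std::vector<bool> &inBounds,
               const std::vector<Index> &indices) {
  assert(inBounds.size() == indices.size() && "rank mismatch");
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (!inBounds[i] && !isKnownZero(indices[i]))
      return true;
  return false;
}

// The `else` branch from the diff: with no attribute, every index must be 0.
bool elseBranch(const std::vector<Index> &indices) {
  return !std::all_of(indices.begin(), indices.end(), isKnownZero);
}
```

With every flag `false`, `!inBounds[i] && !isKnownZero(...)` reduces to `!isKnownZero(...)`, so both branches reject exactly the same inputs; since that default applies whenever `in_bounds` is omitted, this path is the common one rather than an edge case.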

https://github.com/llvm/llvm-project/pull/96218

