[Mlir-commits] [mlir] [mlir][vector] Restrict DropInnerMostUnitDimsTransferWrite (PR #96218)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jul 11 04:12:30 PDT 2024
================
@@ -1394,6 +1394,33 @@ class DropInnerMostUnitDimsTransferWrite
if (dimsToDrop == 0)
return failure();
+ // We need to consider 3 cases for the dim to drop:
+ // 1. if "in bounds", it can safely be assumed that the corresponding
+ // index is equal to 0 (safe to collapse) (*)
+ // 2. if "out of bounds" and the corresponding index is 0, it is
+ // effectively "in bounds" (safe to collapse)
+ // 3. If "out of bounds" and the corresponding index is != 0,
+ // be conservative and bail out (not safe to collapse)
+ // (*) This pattern only drops unit dims, so the only possible "in bounds"
+ // index is "0". This could be added as a folder.
+ // TODO: Deal with 3. by e.g. propagating the "out of bounds" flag to other
+ // dims.
+ bool indexOutOfBounds = true;
+ if (writeOp.getInBounds())
+ indexOutOfBounds = llvm::any_of(
+ llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
+ writeOp.getIndices().take_back(dimsToDrop)),
+ [](auto zipped) {
+ auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
+ auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
+ return !inBounds && nonZeroIdx;
+ });
+ else
+ indexOutOfBounds = !llvm::all_of(
+ writeOp.getIndices().take_back(dimsToDrop), isZeroIndex);
+ if (indexOutOfBounds)
+ return failure();
----------------
banach-space wrote:
> It's an actual fold too (foldTransferInBoundsAttribute, so it's running pretty much all the time).
I knew about the folder, but incorrectly assumed that the "folder" wasn't guaranteed to be run before the pattern. I was wrong: https://github.com/llvm/llvm-project/blob/1ed84a862f9ce3c60251968f23a5405f06458975/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h#L107-L108
> /// Also performs folding and simple dead-code elimination before attempting to
> /// match any of the provided patterns.
So:
> Why not just:
>
> auto inBounds = writeOp.getInBoundsValues();
> auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
> if (llvm::is_contained(droppedInBounds, false))
> return failure();
Yeah, that's perfectly sufficient and covers all the cases. Thanks!
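For reference, a minimal standalone sketch of how the two checks relate. It uses only the C++ standard library, not MLIR: `DimInfo`, `bailsOutOriginal` and `bailsOutSimplified` are hypothetical names made up for illustration, and a missing `in_bounds` attribute is modelled as every dim being out of bounds. The two predicates differ only in case 2 ("out of bounds" with a zero index). The simplification relies on the folder having already rewritten that case to "in bounds" before the pattern runs, which is what the thread above assumes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one dimension of a vector.transfer_write.
struct DimInfo {
  bool inBounds;    // value of the in_bounds flag for this dim
  bool indexIsZero; // whether the corresponding index is a constant 0
};

// Original check from the patch: bail out only in case 3, i.e. when one of
// the last `dimsToDrop` dims is "out of bounds" AND has a non-zero index.
bool bailsOutOriginal(const std::vector<DimInfo> &dims,
                      std::size_t dimsToDrop) {
  assert(dimsToDrop <= dims.size() && "cannot drop more dims than exist");
  for (std::size_t i = dims.size() - dimsToDrop; i < dims.size(); ++i)
    if (!dims[i].inBounds && !dims[i].indexIsZero)
      return true;
  return false;
}

// Simplified check suggested in review: bail out as soon as any of the
// last `dimsToDrop` dims is not marked "in bounds".
bool bailsOutSimplified(const std::vector<DimInfo> &dims,
                        std::size_t dimsToDrop) {
  assert(dimsToDrop <= dims.size() && "cannot drop more dims than exist");
  for (std::size_t i = dims.size() - dimsToDrop; i < dims.size(); ++i)
    if (!dims[i].inBounds)
      return true;
  return false;
}
```

In this model, a dim `{false, true}` (case 2) is accepted by the original check and rejected by the simplified one, so the simplified form is the more conservative of the two if the folder has not run.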
https://github.com/llvm/llvm-project/pull/96218