[PATCH] D79766: [mlir][Linalg] Add pass to remove unit-extent dims from tensor operands of Generic ops.
Mahesh Ravishankar via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu May 14 21:56:30 PDT 2020
mravishankar added inline comments.
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:74
+///
+static bool canFold(MemRefCastOp castOp) {
+ MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
----------------
nicolasvasilache wrote:
> This does not belong here, right?
> It also intersects with this revision: https://reviews.llvm.org/D79759
> Skipping since I am not seeing it used in this revision.
Ah, sorry. This shouldn't be here. Removed.
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:264
+/// modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+ RankedTensorType type,
----------------
nicolasvasilache wrote:
> In the grander scheme, should this function be exposed?
>
> I am asking because the way it is currently used, it takes an indexMap that comes from a GenericOp.
> Are there cases where we'd want to use this but the genericOp may be invalid IR (because inversePermutation would fail)?
>
> Should this be refactored so that it can also help build correct IR that is canonical by construction?
>
> No need to necessarily change, I am mostly asking from your higher-level user's perspective.
I am not sure what you mean by exposed. This is a static function that is useful only at the call site. The issue here might be that this was added to LinalgOps.cpp; I am moving all of this into a separate pass, so it is not visible to op parsing/verification, etc. In effect it isn't exposed outside of the pass.
Regarding the other questions:
1) It should work on any indexMap. I don't see how this relates to whether the genericOp is valid or not, since it only looks at a single index map. The validity of the genericOp using the indexMap constructed here is the concern of the caller AFAICS.
2) I don't know how to refactor this to help build correct IR. I can't trivially see how the construction of each individual index map can be steered so that the final generic op is valid: the validity depends on concatenation and inversion, so how do I break that down to the level of individual index maps?
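To make this concrete, here is a rough sketch (not the patch code; it assumes an `MLIRContext *context` and builds the example by hand with standard MLIR helpers) of the per-operand rewrite that `replaceUnitExtents` computes:
```
AffineExpr d0, d1;
bindDims(context, d0, d1);
// Original: indexing map (d0, d1) -> (0, d1) into tensor<1x4xf32>.
AffineMap oldIndexMap = AffineMap::get(
    /*dimCount=*/2, /*symbolCount=*/0,
    {getAffineConstantExpr(0, context), d1}, context);
// After folding the leading unit dim: (d0, d1) -> (d1) into tensor<4xf32>.
AffineMap newIndexMap =
    AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1}, context);
auto newType = RankedTensorType::get({4}, FloatType::getF32(context));
// A linalg.tensor_reshape with reassociation [(d0, d1)] then connects the
// original tensor<1x4xf32> operand to the new tensor<4xf32> operand.
```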
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:276
+ auto isUnitExtent = [&](int64_t dim) -> bool {
+ return dim < origRank && exprs[dim].isa<AffineConstantExpr>() &&
+ exprs[dim].cast<AffineConstantExpr>().getValue() == 0 &&
----------------
nicolasvasilache wrote:
> ```
> auto zero = getAffineConstantExpr(0, context);
> ```
> then just: `return shape[dim] == 1 && exprs[dim] == zero;`.
>
> Why can dim overflow here?
> IMO, it would be more natural to `while (dim < rank && isUnitExtent(dim))`
Changed. Thanks!
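For reference, the updated check now looks roughly like this (a sketch, not the exact patch code; `shape`, `exprs`, `reassociations`, and `origRank` are the names from the surrounding patch context):
```
// Compare against a cached zero expression instead of isa/cast, and do the
// bounds check in the loop rather than inside the lambda.
auto zero = getAffineConstantExpr(0, context);
auto isUnitExtent = [&](int64_t dim) -> bool {
  return shape[dim] == 1 && exprs[dim] == zero;
};
unsigned dim = 0;
// Fold dimensions that are unit-extent at the start of the tensor.
while (dim < origRank && isUnitExtent(dim))
  reassociations.push_back(getAffineDimExpr(dim++, context));
```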
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:281
+
+ unsigned dim = 0;
+ // Fold dimensions that are unit-extent at the beginning of the tensor.
----------------
nicolasvasilache wrote:
> This should assert that at least one dim is not a unitExtent, otherwise reassociationMaps is empty?
>
Turns out it is OK to have all dims be unitExtent. Added a test for that.
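Concretely, the degenerate case looks like this (a sketch; `type` is the original RankedTensorType passed to `replaceUnitExtents`): for a `tensor<1x1xf32>` operand with indexing map `(d0, d1) -> (0, 0)`, every dimension folds away.
```
// All dims are unit-extent: nothing survives into the new shape, the
// reassociation list stays empty, and the operand collapses to a rank-0
// tensor<f32>.
SmallVector<int64_t, 4> newShape; // stays empty
auto newType = RankedTensorType::get(newShape, type.getElementType());
```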
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:286
+ }
+ for (; dim < origRank; ++dim) {
+ reassociations.push_back(getAffineDimExpr(dim, context));
----------------
nicolasvasilache wrote:
> I find the multi-`dim++` a bit tricky to follow, but I can't see offhand how to make it more readable.
Made this a while loop and made the dim++ more explicit.
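The restructured loop reads roughly as follows (a simplified sketch; `newShape` and `reassociationMaps` are the output vectors built alongside the reassociation expressions):
```
while (dim < origRank) {
  // Start a new reassociation group at the current non-unit dim.
  reassociations.push_back(getAffineDimExpr(dim, context));
  newShape.push_back(shape[dim]);
  // Fold all immediately following unit-extent dims into the same group.
  while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
    ++dim;
    reassociations.push_back(getAffineDimExpr(dim, context));
  }
  reassociationMaps.push_back(AffineMapAttr::get(
      AffineMap::get(origRank, /*symbolCount=*/0, reassociations, context)));
  reassociations.clear();
  ++dim;
}
```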
================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:1478
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
----------------
nicolasvasilache wrote:
> The layering does not seem right to me. Traditionally we do one of two things:
> 1. Create op canonicalization patterns for a specific op (i.e. set hasCanonicalizer = 1 in Ops.td) and later collect them with OpName::getCanonicalizationPatterns(), but this is not what you want here.
> 2. What you want here: create a new pass or transform and put everything there; this way Ops.cpp does not need to include Passes.h.
>
> If you're worried that other things are not visible, they should also be moved to the proper place.
> Your populate impl should resemble:
> ```
> patterns.insert<ReplaceUnitExtentTensors>(context);
> GenericOp::getCanonicalizationPatterns(patterns, context);
> TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
> ```
I agree. I moved everything to a separate file and added the relevant patterns there.
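For context, the standalone pass wrapping these patterns can look roughly like this (a sketch; the actual class name and registration in the patch may differ):
```
namespace {
// Applies the patterns collected by populateLinalgFoldUnitExtentDimsPatterns
// to every function, folding unit-extent dims from tensor operands.
struct LinalgFoldUnitExtentDimsPass
    : public PassWrapper<LinalgFoldUnitExtentDimsPass, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    populateLinalgFoldUnitExtentDimsPatterns(&getContext(), patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
} // namespace
```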
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D79766/new/
https://reviews.llvm.org/D79766