[PATCH] D79766: [mlir][Linalg] Add pass to remove unit-extent dims from tensor operands of Generic ops.

Mahesh Ravishankar via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu May 14 21:56:30 PDT 2020


mravishankar added inline comments.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:74
+///
+static bool canFold(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
----------------
nicolasvasilache wrote:
> This does not belong here, right?
> It also intersects with this revision: https://reviews.llvm.org/D79759
> Skipping since I am not seeing it used in this revision.
Ah, sorry. This shouldn't be here. Removed.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:264
+///   modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+                                                    RankedTensorType type,
----------------
nicolasvasilache wrote:
> In the grander scheme, should this function be exposed? 
> 
> I am asking because, the way it is currently used, it takes an indexMap that comes from a GenericOp.
> Are there cases where we'd want to use this but the genericOp may be invalid IR (because inversePermutation would fail)?
> 
> Should this be refactored so that it can also help build correct IR that is canonical by construction?
> 
> No need to necessarily change; I am mostly asking from your higher-level user's perspective.
I am not sure what you mean by exposed. This is a static function that is useful only at this call site. The issue may be that it was added to LinalgOps.cpp. I am moving all of this into a separate pass, so it is not visible to op parsing, verification, etc. In effect, it isn't exposed outside of the pass.

Regarding the other questions:
1) It should work on any indexMap. I don't see this as related to whether the genericOp is valid, since it only looks at a single index map. As far as I can see, the validity of a genericOp built with the index map constructed here is the caller's concern.
2) I don't know how to refactor this to help build correct IR. I can't easily see how the construction of each index map of a generic op could be steered so that the final generic op is valid. Validity depends on the concatenation and inversion of all the maps; how would I break that down to the level of individual index maps?
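To make point 2 concrete, here is a minimal, self-contained C++ model (not MLIR code) of the whole-op check the reviewer alludes to: all indexing maps are concatenated and inverted so that every loop dimension can be recovered from some operand dimension. `Expr`, `loopDim`, `constant` and `invertConcatenated` are hypothetical names, and each affine result is assumed to be either a bare loop dim or a constant, which ignores compound expressions such as `d0 + d1`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified model of one indexing-map result (not the MLIR
// AffineExpr API): either a bare loop dimension d_i or a constant.
struct Expr {
  bool isDim;
  int64_t value; // loop dim position if isDim, otherwise the constant
};

Expr loopDim(int64_t i) { return {true, i}; }
Expr constant(int64_t c) { return {false, c}; }

using Map = std::vector<Expr>; // the results of one indexing map

// Models inverting the concatenation of all indexing maps: for every loop
// dimension, record the position of the first concatenated result that is
// exactly that dimension. Returns std::nullopt if some loop dimension never
// appears, i.e. its range cannot be derived from any operand shape.
std::optional<std::vector<int64_t>>
invertConcatenated(const std::vector<Map> &maps, int64_t numLoops) {
  std::vector<int64_t> firstPos(numLoops, -1);
  int64_t pos = 0;
  for (const Map &m : maps) {
    for (const Expr &e : m) {
      if (e.isDim && e.value >= 0 && e.value < numLoops &&
          firstPos[e.value] < 0)
        firstPos[e.value] = pos;
      ++pos; // constants still occupy a position in the concatenation
    }
  }
  for (int64_t p : firstPos)
    if (p < 0)
      return std::nullopt;
  return firstPos;
}
```

With maps `(d0, 0)` and `(d0, d1)` both loop dims are found; with `(d0, 0)` alone, `d1` never appears and the model returns `std::nullopt`. Rewriting one map in isolation cannot tell which of these two situations the final op ends up in.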



================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:276
+  auto isUnitExtent = [&](int64_t dim) -> bool {
+    return dim < origRank && exprs[dim].isa<AffineConstantExpr>() &&
+           exprs[dim].cast<AffineConstantExpr>().getValue() == 0 &&
----------------
nicolasvasilache wrote:
> ```
> auto zero = getAffineConstantExpr(0, context);
> ```
> then just: `return shape[dim] == 1 && exprs[dim] == zero;`.
> 
> Why can dim overflow here?
> IMO, it would be more natural to `while (dim < rank && isUnitExtent(dim))`
Changed. Thanks!
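A minimal sketch of the revised check, modeled in plain C++ rather than with MLIR types; `Result`, `isUnitExtent` and `countLeadingUnitDims` are hypothetical stand-ins. Putting the bound check first in the `while` condition means the predicate is never evaluated past the last dimension, so the predicate itself no longer needs a `dim < origRank` guard. In MLIR, affine expressions are uniqued in the `MLIRContext`, which is why a single `==` against one `getAffineConstantExpr(0, context)` is enough there; the model below compares fields instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an AffineExpr result: a constant or a loop dim.
struct Result {
  bool isConstant;
  int64_t value;
};

// Mirrors `shape[dim] == 1 && exprs[dim] == zero`: the dimension has static
// size 1 and is accessed with the constant 0.
bool isUnitExtent(const std::vector<int64_t> &shape,
                  const std::vector<Result> &exprs, size_t dim) {
  return shape[dim] == 1 && exprs[dim].isConstant && exprs[dim].value == 0;
}

// Counts leading unit-extent dims. The bound check is evaluated first, so
// isUnitExtent is never called with an out-of-range dim.
size_t countLeadingUnitDims(const std::vector<int64_t> &shape,
                            const std::vector<Result> &exprs) {
  size_t dim = 0;
  while (dim < shape.size() && isUnitExtent(shape, exprs, dim))
    ++dim;
  return dim;
}
```

For a `1x1x4` tensor accessed as `(0, 0, d0)` the count is 2; a size-1 dimension accessed through a loop dim is not counted, because the index expression is not the constant 0.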


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:281
+
+  unsigned dim = 0;
+  // Fold dimensions that are unit-extent at the beginning of the tensor.
----------------
nicolasvasilache wrote:
> Shouldn't this assert that at least one dim is not a unit extent? Otherwise reassociationMaps is empty.
> 
It turns out it is OK for all dims to be unit-extent. Added a test for that.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:286
+  }
+  for (; dim < origRank; ++dim) {
+    reassociations.push_back(getAffineDimExpr(dim, context));
----------------
nicolasvasilache wrote:
> I find the multiple `dim++` a bit tricky to follow, but I can't see offhand how to make it more readable.
Made this a while loop and made the dim++ more explicit.
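The restructured loop can be sketched as follows, as a plain C++ model under simplifying assumptions: `IndexExpr`, `Replacement`, `foldUnitExtents`, `dimExpr` and `constExpr` are hypothetical, and reassociation groups are plain vectors of dimension indices instead of `AffineMapAttr`s. Leading unit-extent dims join the first kept dimension, and each kept dimension absorbs the unit-extent dims that immediately follow it. It also shows why the all-unit-extent case mentioned earlier is harmless: no group is emitted and the folded tensor is rank 0.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model of an indexing-map result (not MLIR's
// AffineExpr): either a constant or a loop dimension.
struct IndexExpr {
  bool isConstant;
  int64_t value;
};

IndexExpr dimExpr(int64_t i) { return {false, i}; }
IndexExpr constExpr(int64_t c) { return {true, c}; }

struct Replacement {
  std::vector<std::vector<int64_t>> reassociation; // groups of original dims
  std::vector<int64_t> newShape;                   // shape after folding
  std::vector<IndexExpr> newExprs;                 // map results after folding
};

Replacement foldUnitExtents(const std::vector<int64_t> &shape,
                            const std::vector<IndexExpr> &exprs) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  // A size-1 dimension accessed with the constant 0.
  auto isUnitExtent = [&](int64_t d) {
    return shape[d] == 1 && exprs[d].isConstant && exprs[d].value == 0;
  };

  Replacement r;
  std::vector<int64_t> group;
  int64_t dim = 0;
  // Leading unit-extent dims are folded into the first kept dimension.
  while (dim < rank && isUnitExtent(dim))
    group.push_back(dim++);
  while (dim < rank) {
    // Keep the current dimension...
    group.push_back(dim);
    r.newShape.push_back(shape[dim]);
    r.newExprs.push_back(exprs[dim]);
    // ...and fold the unit-extent dimensions that immediately follow it.
    while (dim + 1 < rank && isUnitExtent(dim + 1)) {
      ++dim;
      group.push_back(dim);
    }
    r.reassociation.push_back(group);
    group.clear();
    ++dim;
  }
  // If every dimension was unit-extent, no group is emitted: the folded
  // tensor has rank 0 and `reassociation` is empty.
  return r;
}
```

For example, `tensor<1x4x1x8xf32>` indexed by `(0, d0, 0, d1)` yields the groups `[[0, 1, 2], [3]]` and a folded shape of `4x8`.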


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:1478
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
----------------
nicolasvasilache wrote:
> The layering does not seem right to me. Traditionally we either:
> 1. create canonicalization patterns for a specific op (i.e. set hasCanonicalizer = 1 in Ops.td) and later collect them with OpName::getCanonicalizationPatterns(), but this is not what you want here.
> 2. what you want here: create a new pass or transform and put everything there, so that Ops.cpp does not need to include Passes.h.
> 
> If you're worried that other things are not visible they should also be moved in the proper place.
> Your populate impl should resemble:
> ```
>   patterns.insert<ReplaceUnitExtentTensors>(context);
>   GenericOp::getCanonicalizationPatterns(patterns, context);
>   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
> ```
I agree. I moved everything to a separate file and added the relevant patterns there.


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D79766