[PATCH] D79766: [mlir][Linalg] Add pass to remove unit-extent dims from tensor operands of Generic ops.

Nicolas Vasilache via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue May 12 17:48:25 PDT 2020


nicolasvasilache requested changes to this revision.
nicolasvasilache added a comment.
This revision now requires changes to proceed.

Generally looks good, thanks Mahesh.
I left some code layering comments.



================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:74
+///
+static bool canFold(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
----------------
This does not belong here, right?
It also intersects with this revision: https://reviews.llvm.org/D79759
Skipping it since I do not see it used in this revision.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:264
+///   modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+                                                    RankedTensorType type,
----------------
In the grander scheme, should this function be exposed? 

I am asking because, the way it is currently used, it takes an indexMap that comes from a GenericOp.
Are there cases where we'd want to use this but the genericOp may be invalid IR (because inversePermutation would fail)?

Should this be refactored so that it can also help build correct IR that is canonical by construction?

No need to necessarily change; I am mostly asking from your higher-level user's perspective.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:276
+  auto isUnitExtent = [&](int64_t dim) -> bool {
+    return dim < origRank && exprs[dim].isa<AffineConstantExpr>() &&
+           exprs[dim].cast<AffineConstantExpr>().getValue() == 0 &&
----------------
```
auto zero = getAffineConstantExpr(0, context);
```
then just: `return shape[dim] == 1 && exprs[dim] == zero;`.

Why can dim overflow here?
IMO, it would be more natural to write `while (dim < rank && isUnitExtent(dim))`.
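As a shape-only illustration of the suggested control flow (plain C++ over a shape vector; `countLeadingUnitDims` is a hypothetical helper standing in for the MLIR AffineExpr machinery):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the suggested loop shape: the bounds check lives in the loop
// condition, so the isUnitExtent lambda never has to guard against dim
// running past the rank.
size_t countLeadingUnitDims(const std::vector<int64_t> &shape) {
  auto isUnitExtent = [&](size_t dim) { return shape[dim] == 1; };
  size_t dim = 0;
  while (dim < shape.size() && isUnitExtent(dim))
    ++dim;
  return dim;
}
```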


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:281
+
+  unsigned dim = 0;
+  // Fold dimensions that are unit-extent at the beginning of the tensor.
----------------
This should assert that at least one dim is not a unit extent; otherwise reassociationMaps is empty.
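To illustrate the degenerate case, here is a sketch of the grouping scan over a plain shape vector (hypothetical `groupLeadingUnitDims` helper; the real code builds AffineMap reassociations): when every dimension is a unit extent, the scan consumes all dims before any group is emitted.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Leading unit dims accumulate in a pending group that is only closed when
// a non-unit dim is reached, so if *every* dim is a unit extent no group is
// ever emitted and the result is empty -- the case the assert would catch.
std::vector<std::vector<size_t>>
groupLeadingUnitDims(const std::vector<int64_t> &shape) {
  std::vector<std::vector<size_t>> groups;
  std::vector<size_t> pending;
  size_t dim = 0;
  // Fold leading unit-extent dims into the pending group.
  while (dim < shape.size() && shape[dim] == 1)
    pending.push_back(dim++);
  // Each remaining dim closes a group.
  for (; dim < shape.size(); ++dim) {
    pending.push_back(dim);
    groups.push_back(pending);
    pending.clear();
  }
  return groups;
}
```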



================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:283
+  // Fold dimensions that are unit-extent at the beginning of the tensor.
+  while (isUnitExtent(dim)) {
+    reassociations.push_back(getAffineDimExpr(dim++, context));
----------------
Please drop the trivial braces here and below.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:286
+  }
+  for (; dim < origRank; ++dim) {
+    reassociations.push_back(getAffineDimExpr(dim, context));
----------------
I find the multi-`dim++` a bit tricky to follow but I can't see offhand how to make it more readable.


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:354
+
+    // If any operand types change, insert a reshape to convert from the
+    // original type to the new type.
----------------
Typo: "operand" should be "result".


================
Comment at: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:1478
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
----------------
The layering does not seem right to me. Traditionally we either:
1. create op canonicalization patterns for a specific op (i.e. set hasCanonicalizer = 1 in Ops.td) and later collect them with OpName::getCanonicalizationPatterns(), but this is not what you want here; or
2. what you want here: create a new pass or transform and put everything there, so that Ops.cpp does not need to include Passes.h.

If you're worried that other things are not visible, they should also be moved to the proper place.
Your populate impl should resemble:
```
  patterns.insert<ReplaceUnitExtentTensors>(context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
```


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:838
 
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgFoldUnitExtentDimsPass() {
----------------
This should be its own pass in its own file, with its own patterns; it is unclear to me why it is in the Fusion pass.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D79766/new/

https://reviews.llvm.org/D79766




