[PATCH] D78464: [mlir][Linalg] Add support for fusing linalg.tensor_reshape with linalg.generic operations.

Nicolas Vasilache via Phabricator via llvm-commits <llvm-commits at lists.llvm.org>
Tue Apr 21 14:06:12 PDT 2020


nicolasvasilache requested changes to this revision.
nicolasvasilache added inline comments.
This revision now requires changes to proceed.


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:583
+/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
+/// are "row-major" ordered logically.
+static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
----------------
Adding the example from your test would be informative:
```
E.g.
%0 = op ... : tensor<?x?x4x5xf32>
with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`

and reshape:
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
    tensor<?x?x4x5xf32> into tensor<?x?xf32>

would be rewritten into:
%0 = op ... : tensor<?x?xf32>
with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
```
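In C++ terms, the linearization above (`d1 * 20 + d2 * 5 + d3`) is just a row-major stride computation over each collapsed group. A minimal sketch, with hypothetical names (this is not the patch's actual signature):
```
// Minimal sketch: row-major linearization of one collapsed group.
// `exprs` are the source-map results for the group (e.g. d1, d2, d3) and
// `sizes` their extents (e.g. ?, 4, 5); only the outermost size may be
// dynamic, since it never contributes to a stride.
static AffineExpr linearize(ArrayRef<AffineExpr> exprs,
                            ArrayRef<int64_t> sizes, MLIRContext *context) {
  AffineExpr result = getAffineConstantExpr(0, context);
  int64_t stride = 1;
  // Innermost to outermost: stride(i) is the product of sizes[i+1..n-1].
  for (int i = (int)exprs.size() - 1; i >= 0; --i) {
    result = result + exprs[i] * stride;
    stride *= sizes[i];
  }
  return result;
}
```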


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:644
+  // For reshape to be fused, the collapsed (expanded) dimensions of the operand
+  // (result) must be statically shaped.
+  ArrayRef<int64_t> srcShape =
----------------
Could you please add:
```
// TODO: In the future, this restriction may justify extending linalg.generic to semi-affine maps.
// TODO: Alternatively, fusing across a reshape and pushing the reshape towards the boundary of the function could help too.
```


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:648
+  auto reassociationMaps = reshapeOp.getReassociationMaps();
+  unsigned dim = 0;
+  for (auto collapsedDims : reassociationMaps) {
----------------
I wonder if we could slightly refactor the function exposed here https://reviews.llvm.org/D75575
and just use it instead of rewriting a slightly different version.

Basically the existing function would just need a starting dimPos = 0 by default.

You could then just use that to get:
```
for ... {
  AffineExpr stridedExpression = makeCanonicalStridedLayoutExpr(srcShape.slice(...), context);
  if (!stridedExpression.isPureAffine())
    return AffineMap();
  results.push_back(stridedExpression);
}
return AffineMap::get(...);
```

Then check whether the returned AffineMap is null to know whether the reshape is fusible.

The downside is that with your current code structure you could end up recomputing it twice, but IMO that is largely outweighed by saving a few dozen lines of code that duplicate something we already have in the codebase.

If the recomputation bothers you, you could refactor a bit, but I wouldn't worry about it at this time; these are really trivial computations.
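For concreteness, here is a rough sketch of what the loop could look like once the helper takes a starting dimension position; `reshapeOp` and `srcShape` follow the patch, while the extra `dimPos` parameter on `makeCanonicalStridedLayoutExpr` is the hypothetical extension suggested above:
```
// Rough sketch only; assumes makeCanonicalStridedLayoutExpr gains a
// starting `dimPos` parameter (defaulting to 0) as suggested above.
MLIRContext *context = reshapeOp.getContext();
SmallVector<AffineExpr, 4> results;
unsigned dim = 0;
for (AffineMap collapsedDims : reshapeOp.getReassociationMaps()) {
  unsigned numCollapsed = collapsedDims.getNumResults();
  AffineExpr stridedExpression = makeCanonicalStridedLayoutExpr(
      srcShape.slice(dim, numCollapsed), context, /*dimPos=*/dim);
  // A dynamic size in a non-outermost position shows up as a symbolic
  // stride, making the expression semi-affine; bail out in that case.
  if (!stridedExpression.isPureAffine())
    return AffineMap();
  results.push_back(stridedExpression);
  dim += numCollapsed;
}
return AffineMap::get(srcShape.size(), /*symbolCount=*/0, results, context);
```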


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:681
+    auto consumerIndexMaps = consumer.indexing_maps();
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    fusedIndexMaps.reserve(consumerIndexMaps.size());
----------------
Can we simplify lines 680-695 with something along the lines of:
```
SmallVector<Attribute, 4> fusedIndexMaps(consumerIndexMaps.begin(), consumerIndexMaps.end());
fusedIndexMaps[consumerIdx] = linearize();
```


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:726
+                                  OperationFolder *folder = nullptr) {
+    // Thee indexing_maps for the operands that were originally operands from
+    // the producers are the same as before.
----------------
typo: "Thee" -> "The"


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:729
+    auto producerIndexMaps = producer.indexing_maps();
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    fusedIndexMaps.reserve(producerIndexMaps.size());
----------------
Same here: create the vector from the existing maps, then update the relevant entry; that should be simpler. For instance:
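```
// Sketch of the same create-then-update idiom on the producer side;
// `updatedIdx` and `linearizedMap` are placeholders for whatever the
// patch computes at this point.
auto producerIndexMaps = producer.indexing_maps();
SmallVector<Attribute, 4> fusedIndexMaps(producerIndexMaps.begin(),
                                         producerIndexMaps.end());
fusedIndexMaps[updatedIdx] = AffineMapAttr::get(linearizedMap);
```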


================
Comment at: mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp:783
+    return nullptr;
+  if (auto genericOp = dyn_cast<GenericOp>(producer)) {
+    if (genericOp.hasTensorSemantics())
----------------
If you adopt the style in the previous PR, this conditional would just fold into the `Operation*` based form above.
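For illustration, the folded form might look something like this (my assumption that `hasTensorSemantics()` is reachable through the `LinalgOp` interface; adjust to whatever the previous PR actually exposes):
```
// Hypothetical folded form; assumes the check can go through the
// LinalgOp interface rather than per concrete op.
if (auto linalgOp = dyn_cast<LinalgOp>(producer))
  if (linalgOp.hasTensorSemantics())
    /* ...fuse as in the tensor case above... */;
```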


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78464/new/

https://reviews.llvm.org/D78464




