[Mlir-commits] [mlir] Extend `TilingInterface` to allow more flexible tiling (PR #95422)

llvmlistbot at llvm.org
Thu Jun 13 08:36:53 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Srinath Avadhanula (srinathava)

<details>
<summary>Changes</summary>

Ref: [discourse thread](https://discourse.llvm.org/t/extending-tileconsumerandfuseproducer-to-handle-more-patterns/79340/2)

Problem:

The current version of `transform.structured.fuse` relies on ops implementing the `TilingInterface`. An op that implements this interface returns a `TilingResult`, [defined](https://sourcegraph.robot.car/github.robot.car/cruise/mla-robocomp-llvm-project/-/blob/mlir/include/mlir/Interfaces/TilingInterface.h?L26) as:

```c++
/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};
```

As currently implemented, the algorithm only considers the _last_ operation in `tiledOps` for further fusion.
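To make the limitation concrete, here is a minimal plain-C++ model of the pre-PR candidate collection, which roughly mirrors the `addCandidateSlices(tiledAndFusedOps.back(), candidates)` call in `TileUsingInterface.cpp`. `MockOp` and `candidateSlicesFromLastOp` are hypothetical stand-ins, not MLIR API: an op is reduced to a name plus the ops defining its operands.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR operation (not the MLIR API): an op is
// modeled only by its name and the ops that define its operands.
struct MockOp {
  std::string name;                   // e.g. "tensor.extract_slice", "scf.yield"
  std::vector<MockOp *> operandDefs;  // defining op per operand; nullptr for block arguments
};

// Models the pre-PR behavior: only the operands of the *last* tiled op are
// scanned for slices, so slices feeding earlier tiled ops are ignored.
std::vector<MockOp *> candidateSlicesFromLastOp(const std::vector<MockOp *> &tiledOps) {
  assert(!tiledOps.empty() && "expected at least one tiled op");
  std::vector<MockOp *> candidates;
  for (MockOp *def : tiledOps.back()->operandDefs)
    if (def != nullptr && def->name == "tensor.extract_slice")
      candidates.push_back(def);
  return candidates;
}
```

With two tiled ops that are each fed by their own slice (for example, the two `scf.yield` ops of an `scf.if`), only the slice feeding the second op is returned, so the producer behind the first slice is never considered for fusion.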

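This breaks down when we implement `TilingInterface` for the `tosa.concat` operation as follows (MLIR pseudo-code):

```mlir
// %dim1: size of %arg1 along the concatenated dimension (illustrative name).
%in_first = arith.cmpi ult, %offset, %dim1 : index
%slice = scf.if %in_first -> (tensor<...>) {
  %s1 = tensor.extract_slice %arg1 ...
  scf.yield %s1 : tensor<...>
} else {
  %s2 = tensor.extract_slice %arg2 ...
  scf.yield %s2 : tensor<...>
}
```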

Even if both `scf.yield` ops are returned in the `tiledOps` field, only the last one is considered for further fusion with upstream producers.
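The branch condition in the pseudo-code amounts to the following index arithmetic for a two-input concat along a single dimension. This is an illustrative plain-C++ sketch, not code from the PR: `mapConcatTile`, `ConcatTileSource`, `dim0`, and `dim1` are hypothetical names, and, like the pseudo-code, it assumes a tile never straddles the boundary between the two inputs.

```cpp
#include <cassert>
#include <cstdint>

// Which concat input a tile reads from, and where inside that input.
struct ConcatTileSource {
  int operandIndex;          // 0 -> first input (%arg1), 1 -> second input (%arg2)
  std::int64_t localOffset;  // offset of the tile within the chosen input
};

// Hypothetical helper mirroring the scf.if above. dim0/dim1 are the sizes of
// the two inputs along the concatenated dimension.
ConcatTileSource mapConcatTile(std::int64_t offset, std::int64_t size,
                               std::int64_t dim0, std::int64_t dim1) {
  assert(size > 0 && offset >= 0 && offset + size <= dim0 + dim1);
  if (offset < dim0) {
    // Tile lies in the first input (the "then" branch).
    assert(offset + size <= dim0 && "tile straddles the concat boundary");
    return {0, offset};
  }
  // Tile lies in the second input (the "else" branch); rebase the offset.
  return {1, offset - dim0};
}
```

Because each branch reads from a different input, each creates its own `tensor.extract_slice`, and both slices are candidates for fusion with the producer of their respective input.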

In this PR, we extend `TilingResult` with a list of `tensor::ExtractSliceOp`s, so an implementation of the interface can directly return the slice ops it created to produce the tiled result. The list is plumbed through `TilingResult` -> `SCFTilingResult` -> `SCFFuseProducerOfSliceResult` and is used to add new slices to the worklist of `tensor.extract_slice` ops that the fusion driver processes. The existing Linalg `TilingInterface` implementations (including `SoftmaxOp`) were updated to populate `extractSliceOps`.
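The new flow can be sketched as a breadth-first worklist in plain C++. This is a simplified model under stated assumptions, not the MLIR implementation: `Slice`, `FuseResult`, and `fuseAll` are hypothetical names, a slice is reduced to the name of the producer it reads from, the `fuse` callback stands in for `tileAndFuseProducerOfSlice`, and each producer is fused at most once. The point it illustrates is that the slices to enqueue come from the list the fusion step returns (`extractSliceOps`), not from the operands of a single op.

```cpp
#include <deque>
#include <set>
#include <string>
#include <vector>

// Hypothetical model: a slice only records which producer it reads from.
struct Slice {
  std::string producer;  // empty string means "no fusable producer"
};

// Mirrors the new `extractSliceOps` field: every slice the fused op created,
// including slices nested inside its regions (e.g. both scf.if branches).
struct FuseResult {
  std::vector<Slice> extractSliceOps;
};

// BFS over candidate slices; `fuse` stands in for tileAndFuseProducerOfSlice.
// Returns producers in the order they were fused.
template <typename FuseFn>
std::vector<std::string> fuseAll(const std::vector<Slice> &initialSlices, FuseFn fuse) {
  std::deque<Slice> candidates(initialSlices.begin(), initialSlices.end());
  std::set<std::string> fused;
  std::vector<std::string> order;
  while (!candidates.empty()) {
    Slice s = candidates.front();
    candidates.pop_front();
    if (s.producer.empty() || fused.count(s.producer) != 0)
      continue;  // nothing to fuse, or already fused (simplifying assumption)
    fused.insert(s.producer);
    order.push_back(s.producer);
    // Enqueue every slice reported by the fusion step, not just operands
    // of the last tiled op.
    FuseResult r = fuse(s.producer);
    candidates.insert(candidates.end(), r.extractSliceOps.begin(),
                      r.extractSliceOps.end());
  }
  return order;
}
```

For example, if fusing a `concat` producer reports one slice per `scf.if` branch, both upstream producers reach the worklist and get fused, which is the case the pre-PR logic missed.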

---
Full diff: https://github.com/llvm/llvm-project/pull/95422.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+4) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+7-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+9-2) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+19-18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..fecd33193eb0d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,7 @@ struct SCFTilingResult {
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> extractSliceOps;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..e5ed016d53fc1 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -28,9 +28,13 @@ namespace mlir {
 /// are returned to the caller for further transformations.
 /// - `tiledValues` contains the tiled value corresponding to the result of the
 /// untiled operation.
+/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in
+/// generating the `tiledOps`. Usually these are operands to the `tiledOps`
+/// but they can be embedded in regions owned by `tiledOps`.
 struct TilingResult {
   SmallVector<Operation *> tiledOps;
   SmallVector<Value> tiledValues;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..5198e0bceaa6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  SmallVector<Operation *> sliceOps;
+  for (Value operand : tiledOperands)
+    if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+      sliceOps.push_back(sliceOp);
+
+  return TilingResult{
+      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..f25ccc38ba0a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+    SmallVector<Operation *> sliceOps;
+    for (Value operand : tiledOperands)
+      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+        sliceOps.push_back(sliceOp);
+
+    return TilingResult{
+        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
   }
 
   /// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -247,7 +253,8 @@ struct LinalgOpTilingInterface
 
     return TilingResult{
         tilingResult->tiledOps,
-        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+        tilingResult->extractSliceOps};
   }
 
   /// Method to generate the tiled implementation of an operation from the tile
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..fb3ec2a5fa0a8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
-          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+                       /*extractSliceOps=*/{}};
       return success();
     }
 
@@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // op.
   if (loops.empty()) {
     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues};
+                                tilingResult->tiledValues,
+                                tilingResult->extractSliceOps};
   }
 
   SmallVector<Value> replacements = llvm::map_to_vector(
       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+                              tilingResult->extractSliceOps};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
         .set(origDestinationTensors[resultNumber]);
   }
-  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0],
-                                           tileAndFuseResult->tiledOps};
+  return scf::SCFFuseProducerOfSliceResult{
+      fusableProducer, tileAndFuseResult->tiledValues[0],
+      tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
                   .getDefiningOp<DestinationStyleOpInterface>()) {
         rewriter.setInsertionPoint(tiledDestStyleOp);
         Value newRegionArg = newRegionIterArgs.back();
-        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
         unsigned resultNumber = fusableProducer.getResultNumber();
-        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
-          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
-        });
+        auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
+                             .getDefiningOp<tensor::ExtractSliceOp>();
+        if (origSlice) {
+          origSlice.getSourceMutable().set(newRegionArg);
+        }
       }
       Block *block = rewriter.getInsertionPoint()->getBlock();
       rewriter.setInsertionPoint(block->getTerminator());
@@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   //    operations. If the producers of the source of the `tensor.extract_slice`
   //    can be tiled such that the tiled value is generated in-place, that
   //    effectively tiles + fuses the operations.
-  auto addCandidateSlices = [](Operation *fusedOp,
+  auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
                                std::deque<tensor::ExtractSliceOp> &candidates) {
-    for (Value operand : fusedOp->getOperands())
-      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-        candidates.push_back(sliceOp);
+    for (auto *op : newSliceOps)
+      candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tilingResult->extractSliceOps, candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
+      addCandidateSlices(fusedResult->extractSliceOps, candidates);
     }
   }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/95422

