[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org
Mon May 20 00:29:36 PDT 2024


================
@@ -160,6 +215,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
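The hunk above first translates an operand tile into an iteration-domain tile with `getIterationDomainTileFromOperandTile`, then reuses `getTiledImplementation`. The sketch below is a simplified, static-size picture of that first step. It assumes the operand's indexing map is a projected permutation, encoded as a vector of loop dimensions, and that loop dimensions the operand does not access keep their full range. `Tile` and `iterationDomainTileFromOperandTile` are hypothetical names, not MLIR APIs, and plain integers stand in for `OpFoldResult`.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Per-dimension offsets and sizes of a tile (static integers here; MLIR uses
// OpFoldResult so that values may also be dynamic).
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Hypothetical helper: maps a tile of one operand back to a tile of the
// iteration domain. operandToLoopDim[i] is the loop dimension indexing operand
// dimension i, e.g. the map (d0, d1, d2) -> (d0, d2) is encoded as {0, 2}.
// Loop dimensions the operand does not use keep their full range [0, bound).
std::optional<Tile> iterationDomainTileFromOperandTile(
    const std::vector<int64_t> &loopBounds,
    const std::vector<std::size_t> &operandToLoopDim, const Tile &operandTile) {
  if (operandTile.offsets.size() != operandToLoopDim.size() ||
      operandTile.sizes.size() != operandToLoopDim.size())
    return std::nullopt;
  // Start from the full iteration domain: offset 0, size = loop bound.
  Tile result{std::vector<int64_t>(loopBounds.size(), 0), loopBounds};
  std::vector<bool> seen(loopBounds.size(), false);
  for (std::size_t i = 0; i < operandToLoopDim.size(); ++i) {
    std::size_t dim = operandToLoopDim[i];
    // Reject out-of-range or repeated loop dims (not a projected permutation).
    if (dim >= loopBounds.size() || seen[dim])
      return std::nullopt;
    seen[dim] = true;
    result.offsets[dim] = operandTile.offsets[i];
    result.sizes[dim] = operandTile.sizes[i];
  }
  return result;
}
```

For a matmul-like op with loop bounds (M, N, K) = (128, 256, 64) and LHS map (d0, d1, d2) -> (d0, d2), an LHS tile at offsets [32, 0] with sizes [16, 64] maps to iteration-domain offsets [32, 0, 0] and sizes [16, 256, 64].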
----------------
Yun-Fly wrote:

```
%2 = scf.for iter_args(%dest) {
  %0 = some_producer
  %1 = tensor.insert_slice %0 into %dest
  scf.yield %1
}
%3 = some_tilable_op %2
```

From my understanding, the result of the current implementation will look like this:

```
%2, %3 = scf.for iter_args(%dest1, %dest2) {
  %0 = some_producer
  %1 = tensor.insert_slice %0 into %dest1
  %4 = tensor.extract_slice %1
  %5 = tiled_consumer %4
  %6 = tensor.insert_slice %5 into %dest2
  scf.yield %1, %6
}
```
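In the IR above, `%4` re-reads the tile that `%1` has just written into `%dest1`. Assuming the elided offsets, sizes and strides of the `tensor.extract_slice` match those of the `tensor.insert_slice`, `%4` holds exactly the values of `%0`. The 1-D, unit-stride model below illustrates this with value semantics; `insertSlice` and `extractSlice` are hypothetical helpers, not MLIR APIs.

```cpp
#include <cstddef>
#include <vector>

// 1-D, unit-stride stand-in for tensor.insert_slice: returns a new "tensor"
// with `source` written at `offset`; the caller's `dest` is left untouched.
std::vector<int> insertSlice(const std::vector<int> &source,
                             std::vector<int> dest, std::size_t offset) {
  for (std::size_t i = 0; i < source.size(); ++i)
    dest[offset + i] = source[i];
  return dest;
}

// 1-D, unit-stride stand-in for tensor.extract_slice.
std::vector<int> extractSlice(const std::vector<int> &source,
                              std::size_t offset, std::size_t size) {
  return std::vector<int>(source.begin() + offset,
                          source.begin() + offset + size);
}
```

Extracting at the same offset and size that was just inserted gives back the inserted tile, which is why the tiled consumer can read the producer's tile directly.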

> Maybe we can add an optional operand map to getTiledImplementation that maps from operand index to operand tile?

Besides this approach, I think another option is to directly reset the corresponding operand of the tiled consumer: instead of the result of the `extractSliceOp` (generated for the new tiled consumer), it would use the source of the `InsertSliceOp` (generated by the tiled producer). This can be done in `replaceInsertSliceWithTiledConsumer`, where all the necessary information is available, and it takes only a one-line code change. The resulting IR then becomes:

```
%2, %3 = scf.for iter_args(%dest1, %dest2) {
  %0 = some_producer
  %1 = tensor.insert_slice %0 into %dest1
  %4 = tensor.extract_slice %1 // now unused; expected to be simplified away later
  %5 = tiled_consumer %0 // operand reset from `%4` to `%0`, the source of `%1` and the actual tiled result of the producer
  %6 = tensor.insert_slice %5 into %dest2
  scf.yield %1, %6
}
```
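The proposed operand reset can be modelled on a toy SSA IR. This is a hypothetical sketch, neither MLIR's C++ API nor the code of `replaceInsertSliceWithTiledConsumer`: it redirects a consumer operand from a `tensor.extract_slice` result to the source of the `tensor.insert_slice` that the extract reads from. It assumes the slice offsets, sizes and strides already match, which the toy IR does not model.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy SSA IR for illustration only (hypothetical, not MLIR's C++ API).
// Each op defines one value and reads its operands by value name.
struct Op {
  std::string result;                // defined value, e.g. "%4"
  std::string kind;                  // e.g. "tensor.extract_slice"
  std::vector<std::string> operands; // insert_slice: {source, dest}
};

// Returns the index of the op defining `value`, or ops.size() if none.
std::size_t findDef(const std::vector<Op> &ops, const std::string &value) {
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (ops[i].result == value)
      return i;
  return ops.size();
}

// If operand `operandNumber` of ops[consumerIdx] comes from an extract_slice
// of an insert_slice, make it read the insert_slice's source directly.
// Returns true if the operand was changed.
bool resetOperandToInsertSource(std::vector<Op> &ops, std::size_t consumerIdx,
                                std::size_t operandNumber) {
  Op &consumer = ops[consumerIdx];
  if (operandNumber >= consumer.operands.size())
    return false;
  std::size_t extractIdx = findDef(ops, consumer.operands[operandNumber]);
  if (extractIdx == ops.size() ||
      ops[extractIdx].kind != "tensor.extract_slice")
    return false;
  std::size_t insertIdx = findDef(ops, ops[extractIdx].operands[0]);
  if (insertIdx == ops.size() ||
      ops[insertIdx].kind != "tensor.insert_slice")
    return false;
  consumer.operands[operandNumber] = ops[insertIdx].operands[0];
  return true;
}

// Counts how many operands in the IR read `value`.
std::size_t countUses(const std::vector<Op> &ops, const std::string &value) {
  std::size_t n = 0;
  for (const Op &op : ops)
    for (const std::string &operand : op.operands)
      if (operand == value)
        ++n;
  return n;
}
```

After the reset, the extract_slice result has no remaining uses, so ordinary dead-code cleanup can remove it, which is why the extract_slice in the IR above is expected to disappear later.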

https://github.com/llvm/llvm-project/pull/88712

