[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Quinn Dawkins llvmlistbot at llvm.org
Sat May 18 11:48:26 PDT 2024


================
@@ -160,6 +215,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
----------------
qedawkins wrote:

I think the problem I noted here: https://github.com/llvm/llvm-project/pull/88712/files#r1605830351

is caused by this function returning a tiled implementation based on offsets and sizes of the iteration space, without knowing that one of the operands is already tiled. Concretely, we start with something like this:
```
%2 = scf.for iter_args(%dest) {
  %0 = some_producer
  %1 = tensor.insert_slice %0 into %dest
  scf.yield %1
}
%3 = some_tilable_op %2
```
It gets (intermediately) rewritten to something like this
```
%3 = scf.for iter_args(%dest) {
  %0 = some_producer
  %1 = some_tilable_op %0 // INVALID, not yet tiled
  %2 = tensor.insert_slice %1 into %dest
  scf.yield %2
}
```
And then the tiled implementation is generated
```
%3 = scf.for iter_args(%dest) {
  %0 = some_producer
  %slice = tensor.extract_slice %0
  %1 = some_tiling_result %slice
  %2 = tensor.insert_slice %1 into %dest
  scf.yield %2
}
```

This API appears to enforce a slightly different contract than `getTiledImplementation` does: the tiled implementation must produce the requested tile AND consume the requested operand tile. I don't have a good immediate suggestion that doesn't involve rewriting `getTiledImplementation`. Maybe we can add an optional operand map to `getTiledImplementation` that maps from operand index to operand tile?
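
To make the operand-map idea a bit more concrete, here is a rough standalone sketch in plain C++ with no MLIR dependencies. Everything in it is hypothetical: `Tile`, `OperandTileMap`, and `planOperandAccess` are names invented for illustration and are not part of `TilingInterface`. It also assumes an elementwise-style op, where each operand tile coincides with the iteration-space tile. The point is only that an operand listed in the map is consumed as-is, while every other operand still gets an `extract_slice`.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an (offsets, sizes) pair; not an MLIR type.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Hypothetical map: operand index -> tile that the operand value already holds.
using OperandTileMap = std::map<unsigned, Tile>;

// Toy model of the proposed contract. For each operand, report how a tiled
// implementation would read it:
//   "extract_slice" : operand is untiled, slice it with the iteration tile
//   "use_directly"  : operand already holds exactly the needed tile
//   "mismatch"      : operand is pre-tiled, but with a different tile
// Assumption: elementwise op, so operand tile == iteration-space tile.
std::vector<std::string> planOperandAccess(unsigned numOperands,
                                           const Tile &iterTile,
                                           const OperandTileMap &preTiled = {}) {
  assert(iterTile.offsets.size() == iterTile.sizes.size());
  std::vector<std::string> plan;
  for (unsigned i = 0; i < numOperands; ++i) {
    auto it = preTiled.find(i);
    if (it == preTiled.end()) {
      plan.push_back("extract_slice");
    } else if (it->second.offsets == iterTile.offsets &&
               it->second.sizes == iterTile.sizes) {
      plan.push_back("use_directly");
    } else {
      plan.push_back("mismatch");
    }
  }
  return plan;
}
```

With an empty map this behaves like the current contract (every operand is sliced). Passing the fused operand's index in the map would correspond to skipping the extra `tensor.extract_slice %0` in the last snippet above.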

https://github.com/llvm/llvm-project/pull/88712

