[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org
Mon May 20 23:49:13 PDT 2024


================
@@ -160,6 +215,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
----------------
Yun-Fly wrote:

> Moreover this IR is UB unless %0 is forwardable to %4 as it is potentially reading uninitialized data.
> ```
> %2, %3 = scf.for iter_args(%dest1, %dest2) {
>   %0 = some_producer
>   %1 = tensor.insert_slice %0 into %dest1
>   %4 = tensor.extract_slice %1
>   %5 = tiled_consumer %4
>   %6 = tensor.insert %5 into %dest2
>   scf.yield %1, %6
> }
> ```

As you said above, the root cause is using `tensor.extract_slice` to extract the tiled result. But note that this IR is not the final result: after the operand of the tiled consumer is redirected, `%4` no longer has any uses, as I illustrated [above](https://github.com/llvm/llvm-project/pull/88712#:~:text=Then%2C%20the%20resulting%20IR%20will%20become%3A).

> Maybe we can add an optional operand map to getTiledImplementation that maps from operand index to operand tile?

I guess what you mean here is to also use the operand tile instead of an `extract_slice` inside `getTiledImplementation`, so the only difference is where the replacement happens? Please correct me if I'm wrong. BTW, if invalid temporary IR is already unacceptable, I think the problem starts earlier, when the consumer is cloned after `candidateSliceOp` and its operand is reset [here](https://github.com/llvm/llvm-project/blob/41123a2871c752cdbbdece54cd7fb11a9a95c768/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L1407).

> So the only way that I see to guarantee valid tiling is if getTiledImplementationFromOperandTile takes the operand tile it is expected to consume as an input.

I see your point. However, compared with a one-line fix, I guess it may take more time and require changes to more than one function.

Also, besides the issue you pointed out, please consider another, more complex use case:

```
%0 = linalg.op -> tensor<256,256>
%1 = scf.forall(%dest1) {
  extract_slice
  %2 = scf.for iter_args(%dest2) {
    extract_slice
    %3 = tiled_producer
    %4 = insert_slice %3 into %dest2 [of11, of12][sz11, sz12][1,1] tensor<32,32> to tensor<128,128>
    yield %4
  }
  scf.forall.in_parallel {
    parallel_insert_slice %2 into %dest1 [of21, of22][sz21, sz22][1,1] tensor<128,128> to tensor<256,256>
  }
}
// consumer to be fused
%5 = linalg.add(%0, %1: tensor<256,256>, tensor<256,256>) -> tensor<256,256>
```
With the current implementation, even if we correctly handle one operand tile of `linalg.add` (e.g. `%1`) as you suggested, it is hard to handle the other one (e.g. `%0`). The `offset` and `size` arguments of `getTiledImplementationFromOperandTile` are both computed from `candidateSliceOp` (`%4 = insert_slice %3`) relative to the intermediate buffer (`tensor<128,128>`), not in the coordinates of the original tensor (`tensor<256,256>`). The current implementation then uses them to infer the tiles of the other operands, which leads to wrong or mismatched `offset` and `size` for `%0`, similar to the issue you found.

One option is to compute the real `offset` and `size` in the coordinates of the original, untiled tensor inside `replaceInsertSliceWithTiledConsumer` (which would need to be enhanced) before calling `getTiledImplementationFromOperandTile`. That is also why I suggest resetting the corresponding operand there.
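To make the coordinate mismatch concrete, here is a minimal C++ sketch, assuming every slice in the nest uses unit strides (as in the `[1,1]` strides of the IR above). `Tile2D` and `composeUnitStrideTile` are hypothetical names for illustration only, not MLIR APIs. The helper maps a tile computed relative to the intermediate `tensor<128,128>` buffer back into `tensor<256,256>` coordinates by adding the offsets at which the intermediate buffer is inserted (`[of21, of22]`) to the inner offsets (`[of11, of12]`).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// A 2-D tile described by per-dimension offsets and sizes.
// Assumption: all insert_slice / parallel_insert_slice ops use unit strides.
struct Tile2D {
  std::array<int64_t, 2> offsets;
  std::array<int64_t, 2> sizes;
};

// Hypothetical helper (not an MLIR API): maps a tile expressed relative to an
// intermediate buffer (e.g. the tensor<128,128> iter_arg of the inner scf.for)
// into the coordinates of the enclosing buffer (e.g. tensor<256,256>), given
// the offsets at which the intermediate buffer is inserted into it.
Tile2D composeUnitStrideTile(const Tile2D &inner,
                             const std::array<int64_t, 2> &outerOffsets) {
  Tile2D result;
  for (std::size_t d = 0; d < 2; ++d) {
    assert(inner.sizes[d] >= 0 && "tile sizes must be non-negative");
    // With unit strides, offsets along each dimension simply add up,
    // while the tile size is unchanged.
    result.offsets[d] = outerOffsets[d] + inner.offsets[d];
    result.sizes[d] = inner.sizes[d];
  }
  return result;
}
```

For example, an inner tile at `[32, 64]` with sizes `[32, 32]` inside a `tensor<128,128>` buffer that is itself inserted at `[128, 0]` corresponds to offsets `[160, 64]` in the `tensor<256,256>` space. Assuming `linalg.add` uses identity indexing maps for both operands, that composed tile is what the consumer should read from `%0`, whereas reusing the inner offsets `[32, 64]` directly would select the wrong region.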

Although this scenario may not occur given the current limitations of the implementation, it is an issue that needs to be solved, especially if we want to extend this feature to nested loops.

Anyway, it is up to the author and you to decide how to address the issue you raised in this PR.


https://github.com/llvm/llvm-project/pull/88712

