[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Tue May 21 03:28:26 PDT 2024


================
@@ -160,6 +215,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
----------------
Abhishek-Varma wrote:

Hi @qedawkins

To handle the situation above better, I've now created two APIs, `getTiledImplementationAsProducer` and `getTiledImplementationAsConsumer`, which get around the issue you pointed out earlier. Please take a look.
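The consumer-side function in the diff above first maps an operand tile to an iteration-domain tile and then reuses the ordinary tiled implementation. The following is a minimal standalone sketch of that flow, not MLIR code. It assumes static integer offsets and sizes and an operand indexing map that is a projected permutation. The names `DimTile`, `domainTileFromOperandTile`, `tileOverDomain` and `tileAsConsumer` are hypothetical and only loosely mirror `getIterationDomainTileFromOperandTile` and `getTiledImplementation`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical, standalone model (not MLIR API): static offsets/sizes only.
struct DimTile {
  long offset;
  long size;
};

// Loosely mirrors getIterationDomainTileFromOperandTile. Assumes the operand's
// indexing map is a projected permutation: operand dim i is indexed by loop
// dim operandMap[i]. Loop dims the operand does not index keep their full
// range. Returns std::nullopt when the mapping is not a projected permutation.
std::optional<std::vector<DimTile>> domainTileFromOperandTile(
    const std::vector<DimTile> &iterationDomain,
    const std::vector<std::size_t> &operandMap,
    const std::vector<DimTile> &operandTile) {
  if (operandMap.size() != operandTile.size())
    return std::nullopt;
  std::vector<DimTile> domainTile = iterationDomain;
  std::vector<bool> seen(iterationDomain.size(), false);
  for (std::size_t i = 0; i < operandMap.size(); ++i) {
    std::size_t loopDim = operandMap[i];
    if (loopDim >= iterationDomain.size() || seen[loopDim])
      return std::nullopt;
    seen[loopDim] = true;
    domainTile[loopDim] = operandTile[i];
  }
  return domainTile;
}

// Stand-in for the producer-style tiling: a real version would emit the tiled
// op; this one just returns the iteration-domain tile it was given.
std::vector<DimTile> tileOverDomain(const std::vector<DimTile> &domainTile) {
  return domainTile;
}

// Consumer-side entry point: map the operand tile first, then reuse the
// producer-style tiling, propagating failure like the diff's emitError path.
std::optional<std::vector<DimTile>> tileAsConsumer(
    const std::vector<DimTile> &iterationDomain,
    const std::vector<std::size_t> &operandMap,
    const std::vector<DimTile> &operandTile) {
  auto domainTile =
      domainTileFromOperandTile(iterationDomain, operandMap, operandTile);
  if (!domainTile)
    return std::nullopt;
  return tileOverDomain(*domainTile);
}
```

For a matmul-like op with loops (i, j, k) and an output indexed by (i, j), an output tile with offsets {32, 64} and sizes {16, 32} maps to an iteration-domain tile that keeps the full range of k. Keeping unindexed loop dimensions at full range is a simplifying choice of this sketch.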

Reasons for doing this:
1. Previously, having only `getTiledImplementation` worked because we had no use case for fusing an op as a consumer. Now that we do, making the distinction explicit in the function names seems like the better design to me. Downstream integrations can also simply replace every `getTiledImplementation` call with `getTiledImplementationAsProducer`.
2. During downstream integration, we don't want the additional overhead of passing `std::nullopt` at every invocation of `getTiledImplementation`. If you feel the above design isn't welcome, please let me know whether there's a way to achieve the following:
```
// API:
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
                       ArrayRef<OpFoldResult> offsets,
                       ArrayRef<OpFoldResult> sizes,
                       std::optional<DenseMap<uint64_t, Value>> operandTileMap = std::nullopt);

// User:
// As a producer.
getTiledImplementation(op, b, offsets, sizes);
// As a consumer.
getTiledImplementation(op, b, offsets, sizes, operandTileMap);
```
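One plain C++ way to get the call pattern in item 2, without making every producer-style caller pass `std::nullopt`, is an overload that forwards to the full form. This is a minimal sketch with hypothetical stand-in types: `Value` is an `int`, `std::map` takes the place of `DenseMap`, and `Operation *` / `OpBuilder &` are omitted. Whether such an overload or default can be expressed on an ODS interface method is not covered here.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins, NOT the MLIR types.
using Value = int;
using OperandTileMap = std::map<std::uint64_t, Value>;

// Full form: takes the optional operand-tile map explicitly.
std::string getTiledImplementation(const std::vector<long> &offsets,
                                   const std::vector<long> &sizes,
                                   std::optional<OperandTileMap> operandTileMap) {
  if (!operandTileMap)
    return "producer";
  return "consumer with " + std::to_string(operandTileMap->size()) +
         " operand tile(s)";
}

// Convenience overload: producer-style call sites keep their short shape and
// never spell out std::nullopt themselves.
std::string getTiledImplementation(const std::vector<long> &offsets,
                                   const std::vector<long> &sizes) {
  return getTiledImplementation(offsets, sizes, std::nullopt);
}
```

Because the forwarding overload is an ordinary declaration, every caller that sees it gets the short form. A default argument gives the same effect only if it sits on the declaration the callers actually use.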

I tried using `::std::optional<::llvm::DenseMap<uint64_t, Value>>` in [TilingInterface.td](https://github.com/llvm/llvm-project/blob/375761bcaba9ba5694890ece6eed5db286fc4fd1/mlir/include/mlir/Interfaces/TilingInterface.td#L71) and then assigning a default value (`std::nullopt`) to that argument in the function definition [here in TilingInterfaceImpl.cpp](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp#L115-L117), but every call site still had to pass the argument explicitly. Perhaps I'm missing an ODS-specific declaration where this default value (`std::nullopt`) needs to be specified? I'd appreciate pointers on this in case the current design isn't welcome.
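The behavior described in the paragraph above is consistent with how C++ resolves default arguments: they come from the declaration visible at the call site, not from the definition that eventually runs. Assuming callers reach the method through the interface class generated from TilingInterface.td, a default written only in TilingInterfaceImpl.cpp is never seen by them. The sketch below uses a plain virtual function purely as an analogy (MLIR's interface machinery uses its own generated dispatch wrappers), and `Interface`, `Impl` and `tile` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical analogy: the "interface" declaration is what callers see.
struct Interface {
  virtual ~Interface() = default;
  // The default for `mode` lives on this declaration. If it were removed,
  // `op.tile(3)` through an Interface& would not compile, whatever Impl says.
  virtual int tile(int offset, int mode = 0) const = 0;
};

struct Impl : Interface {
  // This default is only used when the static type at the call site is Impl.
  int tile(int offset, int mode = 1) const override {
    return offset * 10 + mode;
  }
};

// Same object, different static types, different default arguments.
int callThroughInterface(const Interface &op) { return op.tile(3); }  // mode = 0
int callThroughImpl(const Impl &op) { return op.tile(3); }            // mode = 1
```

The takeaway is that a default has to be attached to whatever declaration the call sites actually use.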

https://github.com/llvm/llvm-project/pull/88712

