[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

Matthias Springer llvmlistbot at llvm.org
Sun Jan 14 05:09:30 PST 2024


================
@@ -622,6 +626,47 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   return success();
 }
 
+Block::BlockArgListType ForallOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getRank());
+}
+
+MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
+  return getOutputsMutable();
+}
+
+FailureOr<LoopLikeOpInterface>
+ForallOp::replaceWithAdditionalYields(RewriterBase &rewriter,
+                                      ValueRange newInitOperands,
+                                      bool replaceInitOperandUsesInLoop,
+                                      const NewYieldValuesFn &newYieldValueFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
+
+  // Move the region of the current block to the newly created op.
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      getBody(), newLoopBody,
+      newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
+
+  // Update the terminator.
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
+    rewriter.setInsertionPointToEnd(terminator.getBody());
+    newYieldValueFn(
----------------
matthias-springer wrote:

> I think there is an inherent mismatch between the behavior of `scf.forall` and other loop constructs, in the sense that it doesn't yield anything.

I ran into the same problem when cleaning up the `LoopLikeOpInterface` recently. I was able to add the interface to `scf.while`, but not to `scf.forall`, because of this issue.

I think we should generalize `NewYieldValuesFn` to `RewriteTerminatorFn`. Instead of only returning the values to yield, that lambda would be free to rewrite the terminator in whatever way the op requires.

```c++
/// A function that updates the terminator of a loop-like op during
/// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
/// iter_args. This function should rewrite the loop terminator (`terminator`)
/// such that the loop op verifies successfully.
using RewriteTerminatorFn =
    std::function<void(RewriterBase &rewriter, Operation *terminator,
                       ArrayRef<BlockArgument> newBbArgs)>;
```
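To make the difference between the two callback shapes concrete, here is a toy model in plain C++. It is not MLIR code: `ToyTerminator`, `ToyNewYieldValuesFn`, `ToyRewriteTerminatorFn`, `adaptYieldFn` and `forallSelfInsert` are hypothetical names, and a terminator is reduced to a kind string plus lists of strings. The sketch shows that a value-returning callback is a special case of a terminator-rewriting one that only works for `scf.yield`-style terminators, while the general form can also describe an `scf.forall.in_parallel` update.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for a loop terminator; NOT the MLIR API.
// scf.yield-like terminators keep their yielded values in `operands`;
// scf.forall.in_parallel-like terminators keep nested ops in `body`.
struct ToyTerminator {
  std::string kind;
  std::vector<std::string> operands;
  std::vector<std::string> body;
};

using BbArgs = std::vector<std::string>;

// Shape of the current callback: it only returns extra values to yield.
using ToyNewYieldValuesFn =
    std::function<std::vector<std::string>(const BbArgs &newBbArgs)>;

// Shape of the proposed callback: it may rewrite the terminator freely.
using ToyRewriteTerminatorFn =
    std::function<void(ToyTerminator &terminator, const BbArgs &newBbArgs)>;

// The old callback is a special case of the new one, but only for
// terminators that actually carry yielded operands.
ToyRewriteTerminatorFn adaptYieldFn(ToyNewYieldValuesFn fn) {
  return [fn](ToyTerminator &terminator, const BbArgs &newBbArgs) {
    assert(terminator.kind == "scf.yield" && "terminator must yield values");
    std::vector<std::string> values = fn(newBbArgs);
    terminator.operands.insert(terminator.operands.end(), values.begin(),
                               values.end());
  };
}

// A forall-style rewrite has no operand list to append to; instead it adds
// one self-insert op (a no-op) per newly added shared output.
ToyRewriteTerminatorFn forallSelfInsert() {
  return [](ToyTerminator &terminator, const BbArgs &newBbArgs) {
    for (const std::string &arg : newBbArgs)
      terminator.body.push_back("tensor.parallel_insert_slice " + arg +
                                " into " + arg);
  };
}
```

In this model, `adaptYieldFn` is roughly what the existing `NewYieldValuesFn`-based path would become under the generalized signature, and `forallSelfInsert` is the kind of lambda that the old signature cannot express at all.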

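I think we also have to turn this helper, which is currently a plain function on the interface that ops cannot override, into an interface method so that each op can provide its own implementation: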
```c++
    /// Append the specified additional "init" operands: replace this loop with
    /// a new loop that has the additional init operands. The loop body of this
    /// loop is moved over to the new loop.
    ///
    /// The newly added region iter_args are yielded from the loop.
    ::mlir::FailureOr<::mlir::LoopLikeOpInterface>
        replaceWithAdditionalIterOperands(::mlir::RewriterBase &rewriter,
                                          ::mlir::ValueRange newInitOperands,
                                          bool replaceInitOperandUsesInLoop);
```
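In plain C++ terms, moving from a helper to an interface method is roughly the move from a non-virtual to a virtual member with a default implementation. The sketch below uses ordinary classes rather than MLIR's TableGen-generated interface machinery; `ToyLoop`, `ToyForallLoop` and `appendInits` are hypothetical names. The default body follows the doc comment above (the newly added iter_args are yielded), and a forall-like loop overrides it to opt out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy loop; NOT MLIR's generated interface classes.
// `yielded` lists what the terminator yields (forall-like loops yield nothing).
struct ToyLoop {
  std::vector<std::string> inits;
  std::vector<std::string> yielded;
  virtual ~ToyLoop() = default;

  // Default behavior, as in the doc comment: add the new init operands and
  // yield the corresponding new iter_args unchanged. Returns false on failure.
  virtual bool replaceWithAdditionalIterOperands(
      const std::vector<std::string> &newInits) {
    for (const std::string &init : newInits) {
      inits.push_back(init);
      yielded.push_back("iter_arg(" + init + ")");
    }
    assert(inits.size() == yielded.size() && "each iter_arg must be yielded");
    return true;
  }
};

// A forall-like loop cannot yield, so it overrides the method and opts out,
// mirroring a `return failure();` implementation.
struct ToyForallLoop : ToyLoop {
  bool replaceWithAdditionalIterOperands(
      const std::vector<std::string> &) override {
    return false;
  }
};

// Callers only see the base type; virtual dispatch picks each op's version.
bool appendInits(ToyLoop &loop, const std::vector<std::string> &newInits) {
  return loop.replaceWithAdditionalIterOperands(newInits);
}
```

With a non-overridable helper, `appendInits` would always run the yield-based default, which is exactly what does not fit `scf.forall`.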

For `scf.forall`, this interface method could either:
- `return failure();`
- or: insert a `tensor.parallel_insert_slice` that inserts the newly added bbArg into itself (which is a no-op)
- or: remove this interface method entirely
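The second option can be sanity-checked with a small toy model. It uses 1-D `std::vector<int>` values as stand-in "tensors", and `toyParallelInsertSlice` and `runToyForall` are hypothetical helpers, not MLIR code. Because a full-size slice of a shared output inserted back into itself rewrites every element with its current value, the appended loop result simply equals its init operand, and the original result is unaffected.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ToyTensor = std::vector<int>;

// Hypothetical 1-D model of tensor.parallel_insert_slice: write `source`
// into `dest` starting at `offset` (unit strides, no real parallelism).
void toyParallelInsertSlice(const ToyTensor &source, ToyTensor &dest,
                            std::size_t offset) {
  assert(offset + source.size() <= dest.size() && "slice out of bounds");
  for (std::size_t i = 0; i < source.size(); ++i)
    dest[offset + i] = source[i];
}

// Models a forall with one original shared output and one newly appended
// shared output. The appended output's only terminator op inserts the
// bbArg into itself at full size. Returns both loop results.
std::vector<ToyTensor> runToyForall(ToyTensor original, ToyTensor appended) {
  // `original` and `appended` play the role of the two shared_outs.
  for (std::size_t iv = 0; iv < original.size(); ++iv) {
    // Original body: iteration `iv` writes element `iv`.
    toyParallelInsertSlice({static_cast<int>(iv) * 10}, original, iv);
    // Appended op: full-size self-insert, which leaves `appended` unchanged.
    toyParallelInsertSlice(appended, appended, 0);
  }
  return {original, appended};
}
```

For example, `runToyForall({1, 2, 3}, {7, 8})` yields `{0, 10, 20}` for the original result and `{7, 8}` for the appended one.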


https://github.com/llvm/llvm-project/pull/77874

