[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #108318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 20:00:36 PDT 2024


================
@@ -1754,79 +1725,108 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (failed(tileAndFuseResult)) {
     return failure();
   }
-  rewriter.replaceAllUsesWith(
-      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
-      clonedInsertSliceOp.getSource());
-
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
+  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+                              clonedInsertSliceOp.getSource());
 
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+  // 7. Reconstruct [nested] loop with new inits.
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+    // 8. Set inner insertPoint right before tiled consumer op.
+    innerRewriter.setInsertionPoint(tiledConsumerOp);
 
-  // 11. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
+    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+    // 9. Check all insert stride is 1.
+    if (llvm::any_of(strides, [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
       return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
+          candidateSliceOp, "containingOp's result yield with stride");
     }
-  }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
-  if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
-    fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
-  } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
-    fixTerminatorSCFInParallel(
-        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
-  }
+    // 10. Try to get iter domain position from input position.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
+      return rewriter.notifyMatchFailure(
+          tiledConsumerOp,
+          "can't get iter domain position from input position");
+    }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+    // 11. Try to fetch the offset and size for all results of the cloned
+    // consumer. This would then be used to form the corresponding
+    // tensor.insert_slice/parallel_insert_slice later.
+    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+        totalNumResultsOfConsumer);
+    SmallVector<SmallVector<OpFoldResult>> resultSizes(
+        totalNumResultsOfConsumer);
+    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+      if (failed(tiledConsumerOp.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultOffsets[idx], resultSizes[idx]))) {
+        return rewriter.notifyMatchFailure(
+            tiledConsumerOp,
+            "can't get result domain position from iter domain position");
+      }
+    }
+
+    // 12. Create `extract_slice` for `iter_args` for DPS operation if
+    // necessary.
+    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+            tiledConsumerOp.getOperation())) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
+        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+            loc, newRegionArg, resultOffsets[index], resultSizes[index],
+            SmallVector<OpFoldResult>(resultOffsets[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        // Make C++ 17 happy, otherwise it will throw error `captured structured
+        // bindings are a C++20 extension`.
----------------
Yun-Fly wrote:

Ok, changed.

https://github.com/llvm/llvm-project/pull/108318


More information about the Mlir-commits mailing list