[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to single nested `scf.for` (PR #94190)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 09:07:05 PDT 2024


================
@@ -1754,79 +1723,105 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (failed(tileAndFuseResult)) {
     return failure();
   }
-  rewriter.replaceAllUsesWith(
-      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
-      clonedInsertSliceOp.getSource());
-
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
+  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+                              clonedInsertSliceOp.getSource());
 
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+  // 7 - Reconstruct [nested] loop with new inits.
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+    // 9. Set inner insertPoint right before tiled consumer op.
----------------
MaheshRavishankar wrote:

Nit: this step should be numbered 8, not 9 — the preceding step in this function is 7.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list