[Mlir-commits] [mlir] [flang] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jan 18 00:35:10 PST 2024


================
@@ -584,6 +588,63 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
   return cast<LoopLikeOpInterface>(newLoop.getOperation());
 }
 
+FailureOr<LoopLikeOpInterface> ForOp::yieldTiledValuesAndReplace(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    const YieldTiledValuesFn &yieldTiledValuesFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+
+  auto inits = llvm::to_vector(getInitArgs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<ForOp>(
+      getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
+      [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  auto yieldOp = cast<scf::YieldOp>(newLoop.getBody()->getTerminator());
+  rewriter.setInsertionPoint(yieldOp);
+
+  SmallVector<Value> tiledValues;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  ValueRange newRegionIterArgs =
+      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
+  if (failed(yieldTiledValuesFn(rewriter, getLoc(), newLoop.getInductionVar(),
+                                newRegionIterArgs, tiledValues, resultOffsets,
+                                resultSizes))) {
+    return rewriter.notifyMatchFailure(getOperation(),
+                                       "failed to get tiled values");
+  }
+
+  if (tiledValues.size() != resultOffsets.size() ||
----------------
nicolasvasilache wrote:

If my comment just above applies, the same issue exists here.
Consider making the implementer of `yieldTiledValuesFn` responsible for guaranteeing this invariant, and using an assert here instead?
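Below is a minimal, self-contained sketch of the split being suggested, written in plain C++ rather than against MLIR. A failure reported by the callback stays a recoverable error that the caller turns into a match failure. The requirement that the tiled values, offsets, and sizes all have the same length becomes part of the callback's contract and is checked only with `assert`. `Value`, `OffsetList`, `YieldTiledValuesFn`, and `yieldTiledValues` are hypothetical stand-ins for illustration. The real code works with `OpFoldResult`, `LogicalResult`, and the rewriter, so this is not the MLIR API.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR types; illustrative only.
using Value = std::string;
using OffsetList = std::vector<int>;

// Hypothetical callback mirroring the shape of YieldTiledValuesFn: on success
// it appends one tiled value plus one offset list and one size list for each
// result. Returning false signals a recoverable failure.
using YieldTiledValuesFn = std::function<bool(
    std::vector<Value> &tiledValues, std::vector<OffsetList> &resultOffsets,
    std::vector<OffsetList> &resultSizes)>;

// Returns false only when the callback itself reports failure. A length
// mismatch is treated as a bug in the callback, so it is asserted rather than
// turned into a runtime failure path.
bool yieldTiledValues(const YieldTiledValuesFn &fn,
                      std::vector<Value> &tiledValues,
                      std::vector<OffsetList> &resultOffsets,
                      std::vector<OffsetList> &resultSizes) {
  if (!fn(tiledValues, resultOffsets, resultSizes))
    return false; // Recoverable: caller reports a match failure.

  // Invariant guaranteed by the callback's contract (checked in debug builds).
  assert(tiledValues.size() == resultOffsets.size() &&
         tiledValues.size() == resultSizes.size() &&
         "yieldTiledValuesFn must produce one offset and size list per value");
  return true;
}
```

With this split, builds that define `NDEBUG` pay nothing for the check. The trade-off is that a misbehaving callback is no longer diagnosed outside debug builds, which is the cost the review comment is weighing against the extra failure path.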

https://github.com/llvm/llvm-project/pull/77874
