[Mlir-commits] [mlir] [mlir][SCF] Allow using a custom operation to generate loops with `mlir::tileUsingSCF`. (PR #159660)

Quinn Dawkins llvmlistbot at llvm.org
Mon Sep 22 13:17:06 PDT 2025


================
@@ -468,6 +469,153 @@ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestTileUsingCustomLoopOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
+    TransformRewriter &transformRewriter, TransformResults &transformResults,
+    TransformState &state) {
+  auto target =
+      dyn_cast<TilingInterface>(*state.getPayloadOps(getRootOp()).begin());
+  if (!target) {
+    emitOpError("expected root operation to implement `TilingInterface`");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);
+
+  scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
+      [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+          ArrayRef<OpFoldResult> givenTileSizes,
+          ValueRange outerDestinationTensors)
+      -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
+    // Check that the strides are all 1 (to make it easier in the test).
+    if (llvm::any_of(loopRanges, [](Range r) {
+          return !isConstantIntValue(r.stride, 1);
+        })) {
+      return emitOpError("unable to handle loop ranges with strides != 1");
+    }
+    // For testing disallow any of the tile sizes being 0.
+    if (llvm::any_of(givenTileSizes, isZeroInteger)) {
+      return emitOpError("unhandled case of zero tile size");
+    }
+    // For testing, only handle tensor tiling.
+    if (outerDestinationTensors.empty()) {
+      return emitOpError("expected destination tensors");
+    }
+
+    // Compute the number of iterations for each of the loops.
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize
+
+    SmallVector<OpFoldResult> allNumIters;
+    allNumIters.reserve(loopRanges.size());
+    for (auto [loopRange, tileSize] :
+         llvm::zip_equal(loopRanges, givenTileSizes)) {
----------------
qedawkins wrote:

Right, but it's only an assert, so it gives no graceful failure in release builds. Failing gracefully would be better unless this is actually guaranteed.

https://github.com/llvm/llvm-project/pull/159660


More information about the Mlir-commits mailing list