[Mlir-commits] [mlir] [mlir][SCF] Allow using a custom operation to generate loops with `mlir::tileUsingSCF`. (PR #159660)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Sep 22 13:17:06 PDT 2025
================
@@ -468,6 +469,153 @@ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
: DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestTileUsingCustomLoopOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
+ TransformRewriter &transformRewriter, TransformResults &transformResults,
+ TransformState &state) {
+ auto target =
+ dyn_cast<TilingInterface>(*state.getPayloadOps(getRootOp()).begin());
+ if (!target) {
+ emitOpError("expected root operation to implement `TilingInterface`");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);
+
+ scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
+ [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> givenTileSizes,
+ ValueRange outerDestinationTensors)
+ -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
+ // Check that the strides are all 1 (to make it easier in the test).
+ if (llvm::any_of(loopRanges, [](Range r) {
+ return !isConstantIntValue(r.stride, 1);
+ })) {
+ return emitOpError("unable to handle loop ranges with strides != 1");
+ }
+ // For testing disallow any of the tile sizes being 0.
+ if (llvm::any_of(givenTileSizes, isZeroInteger)) {
+ return emitOpError("unhandled case of zero tile size");
+ }
+ // For testing, only handle tensor tiling.
+ if (outerDestinationTensors.empty()) {
+ return emitOpError("expected destination tensors");
+ }
+
+ // Compute the number of iterations for each of the loops.
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize
+
+ SmallVector<OpFoldResult> allNumIters;
+ allNumIters.reserve(loopRanges.size());
+ for (auto [loopRange, tileSize] :
+ llvm::zip_equal(loopRanges, givenTileSizes)) {
----------------
qedawkins wrote:
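Right, but it's an assert (`llvm::zip_equal` only asserts that `loopRanges` and `givenTileSizes` have the same length). Failing gracefully would be better unless equal lengths are actually guaranteed.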
https://github.com/llvm/llvm-project/pull/159660
More information about the Mlir-commits
mailing list