[Mlir-commits] [mlir] [mlir][SCF] Add support for peeling the first iteration out of the loop (PR #74015)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 30 16:53:23 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Vivian (yzhang93)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/74015.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+14-5)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+3)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+5)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+17-6)
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (+93-14)
- (added) mlir/test/Dialect/SCF/for-loop-peeling-front.mlir (+106)
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+30)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 14df7e23a430fb1..a9c70c9e3433f23 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -106,9 +106,11 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
- let summary = "Peels the last iteration of the loop";
+ let summary = "Peels the first or last iteration of the loop";
let description = [{
- Updates the given loop so that its step evenly divides its range and puts
+ Rewrite the given loop with a main loop and a partial (first or last) loop.
+ When the `peelFront` option is set as true, the first iteration is peeled off.
+ Otherwise, updates the given loop so that its step evenly divides its range and puts
the remaining iteration into a separate loop or a conditional.
In the absence of sufficient static information, this op may peel a loop,
@@ -118,9 +120,15 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
This operation ignores non-scf::ForOp ops and drops them in the return.
- This operation returns two scf::ForOp Ops, with the first Op satisfying
- the postcondition: "the loop trip count is divisible by the step". The
- second loop Op contains the remaining iteration. Note that even though the
+ When `peelFront` is true, this operation returns two scf::ForOp Ops, the
+ first scf::ForOp corresponds to the first iteration of the loop which can
+ be canonicalized away in the following optimization. The second loop Op
+ contains the remaining iteration, and the new lower bound is the original
+ lower bound plus the number of steps.
+
+ For the other case, this operation returns two scf::ForOp Ops, with the first
+ Op satisfying the postcondition: "the loop trip count is divisible by the step".
+ The second loop Op contains the remaining iteration. Note that even though the
Payload IR modification may be performed in-place, this operation consumes
the operand handle and produces a new one.
@@ -131,6 +139,7 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
let arguments =
(ins Transform_ScfForOp:$target,
+ DefaultValuedAttr<BoolAttr, "false">:$peel_front,
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
let results = (outs TransformHandleTypeInterface:$peeled_loop,
TransformHandleTypeInterface:$remainder_loop);
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index bbc673f44977ac9..350611ad86873d0 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -32,6 +32,9 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
let summary = "Peel `for` loops at their upper bounds.";
let constructor = "mlir::createForLoopPeelingPass()";
let options = [
+ Option<"peelFront", "peel-front", "bool",
+ /*default=*/"false",
+ "Peel the first iteration out of the loop.">,
Option<"skipPartial", "skip-partial", "bool",
/*default=*/"true",
"Do not peel loops inside of the last, partial iteration of another "
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 347beb9e4c64f8c..d736d049fd0932e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -81,6 +81,11 @@ void naivelyFuseParallelOps(Region &region);
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
scf::ForOp &partialIteration);
+/// Peel the first iteration out of the scf.for loop. If there is only one
+/// iteration, return the original loop.
+LogicalResult peelFirstIterationForLoop(RewriterBase &rewriter, ForOp forOp,
+ scf::ForOp &partialIteration);
+
/// Tile a parallel loop of the form
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 62370604142cd5b..1fdd6f1a7015a6d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -193,13 +193,24 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
scf::ForOp result;
- LogicalResult status =
- scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
- if (failed(status)) {
- DiagnosedSilenceableFailure diag = emitSilenceableError()
- << "failed to peel";
- return diag;
+ if (getPeelFront()) {
+ LogicalResult status =
+ scf::peelFirstIterationForLoop(rewriter, target, result);
+ if (failed(status)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "failed to peel the first iteration";
+ return diag;
+ }
+ } else {
+ LogicalResult status =
+ scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
+ if (failed(status)) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to peel the last iteration";
+ return diag;
+ }
}
+
results.push_back(target);
results.push_back(result);
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 93b81794ec2652e..9198cdf54328c72 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -191,6 +191,23 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
});
}
+static void removeAffineOpInsideFirstIteration(RewriterBase &rewriter,
+ ForOp partialIteration,
+ Value previousUb) {
+ Value partialIv = partialIteration.getInductionVar();
+ Value step = partialIteration.getStep();
+
+ partialIteration.walk([&](Operation *affineOp) {
+ if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+ return WalkResult::advance();
+
+ // Hack to reuse the existing utils when ub - iv >= step.
+ (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
+ step, /*insideLoop=*/true);
+ return WalkResult::advance();
+ });
+}
+
LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
ForOp forOp,
ForOp &partialIteration) {
@@ -205,32 +222,92 @@ LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
return success();
}
+LogicalResult mlir::scf::peelFirstIterationForLoop(RewriterBase &b, ForOp forOp,
+ ForOp &firstIteration) {
+ RewriterBase::InsertionGuard guard(b);
+ auto lbInt = getConstantIntValue(forOp.getLowerBound());
+ auto ubInt = getConstantIntValue(forOp.getUpperBound());
+ auto stepInt = getConstantIntValue(forOp.getStep());
+
+ // Peeling is not needed if there is one or less iteration.
+ if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) / *stepInt <= 1)
+ return success();
+
+ // Slow path: Examine the ops that define lb, ub and step.
+ AffineExpr sym0, sym1, sym2;
+ bindSymbols(b.getContext(), sym0, sym1, sym2);
+ SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep()};
+ AffineMap map = AffineMap::get(0, 3, {(sym1 - sym0) % sym2});
+ affine::fullyComposeAffineMapAndOperands(&map, &operands);
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(0)))
+ if (constExpr.getValue() == 0)
+ return failure();
+
+ // New lower bound for main loop: %lb + %step
+ auto ubMap = AffineMap::get(0, 3, {sym0 + sym2});
+ b.setInsertionPoint(forOp);
+ auto loc = forOp.getLoc();
+ Value splitBound = b.createOrFold<AffineApplyOp>(
+ loc, ubMap,
+ ValueRange{forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep()});
+
+ // Peel the first iteration.
+ b.setInsertionPoint(forOp);
+ firstIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
+ firstIteration.getUpperBoundMutable().assign(splitBound);
+
+ // Update main loop with new lower bound.
+ forOp.getInitArgsMutable().assign(firstIteration->getResults());
+ b.updateRootInPlace(
+ forOp, [&]() { forOp.getLowerBoundMutable().assign(splitBound); });
+
+ return success();
+}
+
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
namespace {
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
- ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
- : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
+ ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
+ : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
+ skipPartial(skipPartial) {}
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
// Do not peel already peeled loops.
if (forOp->hasAttr(kPeeledLoopLabel))
return failure();
- if (skipPartial) {
- // No peeling of loops inside the partial iteration of another peeled
- // loop.
- Operation *op = forOp.getOperation();
- while ((op = op->getParentOfType<scf::ForOp>())) {
- if (op->hasAttr(kPartialIterationLabel))
- return failure();
+
+ scf::ForOp partialIteration;
+ // The case for peeling the first iteration of the loop.
+ if (peelFront) {
+ if (failed(
+ peelFirstIterationForLoop(rewriter, forOp, partialIteration))) {
+ return failure();
}
+ // Remove affine.min op in the first (partial) iteration.
+ removeAffineOpInsideFirstIteration(rewriter, partialIteration,
+ forOp.getUpperBound());
+
+ } else {
+ if (skipPartial) {
+ // No peeling of loops inside the partial iteration of another peeled
+ // loop.
+ Operation *op = forOp.getOperation();
+ while ((op = op->getParentOfType<scf::ForOp>())) {
+ if (op->hasAttr(kPartialIterationLabel))
+ return failure();
+ }
+ }
+ // Apply loop peeling.
+ if (failed(
+ peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
+ return failure();
}
- // Apply loop peeling.
- scf::ForOp partialIteration;
- if (failed(peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
- return failure();
+
// Apply label, so that the same loop is not rewritten a second time.
rewriter.updateRootInPlace(partialIteration, [&]() {
partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
@@ -242,6 +319,8 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
return success();
}
+ bool peelFront;
+
/// If set to true, loops inside partial iterations of another peeled loop
/// are not peeled. This reduces the size of the generated code. Partial
/// iterations are not usually performance critical.
@@ -273,7 +352,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
auto *parentOp = getOperation();
MLIRContext *ctx = parentOp->getContext();
RewritePatternSet patterns(ctx);
- patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
+ patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
// Drop the markers.
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
new file mode 100644
index 000000000000000..65214f55f241f94
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt %s -scf-for-loop-peeling=peel-front=true -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
+// CHECK: func @fully_static_bounds(
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C4_I32:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index
+// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]]
+// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+// CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[C4_I32]] : i32
+// CHECK: scf.yield %[[INIT]]
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C4]] to %[[C17]]
+// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[FIRST]]) -> (i32) {
+// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
+// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+// CHECK: scf.yield %[[ADD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @fully_static_bounds() -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %lb = arith.constant 0 : index
+ %step = arith.constant 4 : index
+ %ub = arith.constant 17 : index
+ %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %casted = arith.index_cast %s : index to i32
+ %0 = arith.addi %arg, %casted : i32
+ scf.yield %0 : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
+// CHECK: func @no_loop_results(
+// CHECK-SAME: %[[UB:.*]]: index, %[[MEMREF:.*]]: memref<i32>
+// CHECK-DAG: %[[C4_I32:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C4]] {
+// CHECK: %[[LOAD:.*]] = memref.load %[[MEMREF]][]
+// CHECK: %[[ADD:.*]] = arith.addi %[[LOAD]], %[[C4_I32]] : i32
+// CHECK: memref.store %[[ADD]], %[[MEMREF]]
+// CHECK: }
+// CHECK: scf.for %[[IV2:.*]] = %[[C4]] to %[[UB]] step %[[C4]] {
+// CHECK: %[[REM:.*]] = affine.min #[[MAP]](%[[UB]], %[[IV2]])[%[[C4]]]
+// CHECK: %[[LOAD2:.*]] = memref.load %[[MEMREF]][]
+// CHECK: %[[CAST2:.*]] = arith.index_cast %[[REM]]
+// CHECK: %[[ADD2:.*]] = arith.addi %[[LOAD2]], %[[CAST2]]
+// CHECK: memref.store %[[ADD2]], %[[MEMREF]]
+// CHECK: }
+// CHECK: return
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @no_loop_results(%ub : index, %d : memref<i32>) {
+ %c0_i32 = arith.constant 0 : i32
+ %lb = arith.constant 0 : index
+ %step = arith.constant 4 : index
+ scf.for %iv = %lb to %ub step %step {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %r = memref.load %d[] : memref<i32>
+ %casted = arith.index_cast %s : index to i32
+ %0 = arith.addi %r, %casted : i32
+ memref.store %0, %d[] : memref<i32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+// CHECK: func @fully_dynamic_bounds(
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[UB]], %[[STEP]]]
+// CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
+// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[STEP]] : index to i32
+// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+// CHECK: scf.yield %[[ADD]]
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = scf.for %[[IV2:.*]] = %[[NEW_UB]] to %[[UB]]
+// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC2:.*]] = %[[FIRST]]) -> (i32) {
+// CHECK: %[[REM:.*]] = affine.min #[[MAP1]](%[[UB]], %[[IV2]])[%[[STEP]]]
+// CHECK: %[[CAST2:.*]] = arith.index_cast %[[REM]]
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ACC2]], %[[CAST2]]
+// CHECK: scf.yield %[[ADD2]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0) -> i32 {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %casted = arith.index_cast %s : index to i32
+ %0 = arith.addi %arg, %casted : i32
+ scf.yield %0 : i32
+ }
+ return %r : i32
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index 74601cf5b34a178..93ebf67f8b71333 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -77,6 +77,36 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @loop_peel_first_iter_op
+func.func @loop_peel_first_iter_op() {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[C41:.+]] = arith.constant 41
+ // CHECK: %[[C5:.+]] = arith.constant 5
+ // CHECK: %[[C5_0:.+]] = arith.constant 5
+ // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C5_0]] step %[[C5]]
+ // CHECK: arith.addi
+ // CHECK: scf.for %{{.+}} = %[[C5_0]] to %[[C41]] step %[[C5]]
+ // CHECK: arith.addi
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 41 : index
+ %2 = arith.constant 5 : index
+ scf.for %i = %0 to %1 step %2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ %main_loop, %remainder = transform.loop.peel %1 {peel_front = true} : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
+
+// -----
+
func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/74015
More information about the Mlir-commits
mailing list