[Mlir-commits] [mlir] [mlir][SCF] Add support for peeling the first iteration out of the loop (PR #74015)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 20:19:21 PST 2023


https://github.com/yzhang93 updated https://github.com/llvm/llvm-project/pull/74015

>From a3c8551acf5c94ef855f27c6a2ae4b97592ef703 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Tue, 28 Nov 2023 22:00:02 -0800
Subject: [PATCH 1/2] Add support for peeling the first iteration out of the
 loop

In some use cases the first iteration must be peeled out of an scf.for loop, so
that the peeled single-iteration loop can be canonicalized away and the fill op
can be fused into the inner scf.forall loop. For example, given the nested loop

```
  linalg.fill ins(...) outs(...)
  scf.for %arg = %lb to %ub step %step
    scf.forall ...
```

After peeling, canonicalizing away the peeled single-iteration loop, and fusing the fill, we get

```
  scf.forall ...
    linalg.fill ins(...) outs(...)
  scf.for %arg = %(lb + step) to %ub step %step
    scf.forall ...
```

This patch adds support for peeling the first iteration out of the loop, reusing
the existing peeling functions as much as possible.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  19 +++-
 .../mlir/Dialect/SCF/Transforms/Passes.td     |   3 +
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |   5 +
 .../SCF/TransformOps/SCFTransformOps.cpp      |  23 +++-
 .../SCF/Transforms/LoopSpecialization.cpp     |  99 +++++++++++++---
 .../Dialect/SCF/for-loop-peeling-front.mlir   | 106 ++++++++++++++++++
 mlir/test/Dialect/SCF/transform-ops.mlir      |  30 +++++
 7 files changed, 260 insertions(+), 25 deletions(-)
 create mode 100644 mlir/test/Dialect/SCF/for-loop-peeling-front.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 14df7e23a430f..a9c70c9e3433f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -106,9 +106,11 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
 def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
-  let summary = "Peels the last iteration of the loop";
+  let summary = "Peels the first or last iteration of the loop";
   let description = [{
-     Updates the given loop so that its step evenly divides its range and puts
+     Rewrites the given loop into a main loop and a partial (first or last) loop.
+     When the `peel_front` attribute is set to true, the first iteration is peeled off.
+     Otherwise, updates the given loop so that its step evenly divides its range and puts
      the remaining iteration into a separate loop or a conditional.
 
      In the absence of sufficient static information, this op may peel a loop,
@@ -118,9 +120,15 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
 
      This operation ignores non-scf::ForOp ops and drops them in the return.
 
-     This operation returns two scf::ForOp Ops, with the first Op satisfying
-     the postcondition: "the loop trip count is divisible by the step". The
-     second loop Op contains the remaining iteration. Note that even though the
+     When `peel_front` is true, this operation returns two scf::ForOp Ops. The
+     second one covers only the first iteration of the loop and can be
+     canonicalized away by later optimizations. The first one (the updated loop)
+     contains the remaining iteration, and the new lower bound is the original
+     lower bound plus the number of steps.
+
+     For the other case, this operation returns two scf::ForOp Ops, with the first
+     Op satisfying the postcondition: "the loop trip count is divisible by the step".
+     The second loop Op contains the remaining iteration. Note that even though the
      Payload IR modification may be performed in-place, this operation consumes
      the operand handle and produces a new one.
 
@@ -131,6 +139,7 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
 
   let arguments =
       (ins Transform_ScfForOp:$target,
+           DefaultValuedAttr<BoolAttr, "false">:$peel_front,
            DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
   let results = (outs TransformHandleTypeInterface:$peeled_loop,
                       TransformHandleTypeInterface:$remainder_loop);
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index bbc673f44977a..350611ad86873 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -32,6 +32,9 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
   let summary = "Peel `for` loops at their upper bounds.";
   let constructor = "mlir::createForLoopPeelingPass()";
   let options = [
+    Option<"peelFront", "peel-front", "bool",
+           /*default=*/"false",
+           "Peel the first iteration out of the loop.">,
     Option<"skipPartial", "skip-partial", "bool",
            /*default=*/"true",
            "Do not peel loops inside of the last, partial iteration of another "
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 347beb9e4c64f..d736d049fd093 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -81,6 +81,11 @@ void naivelyFuseParallelOps(Region &region);
 LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
                                            scf::ForOp &partialIteration);
 
+/// Peel the first iteration out of the scf.for loop. If there is only one
+/// iteration, return the original loop.
+LogicalResult peelFirstIterationForLoop(RewriterBase &rewriter, ForOp forOp,
+                                        scf::ForOp &partialIteration);
+
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 62370604142cd..1fdd6f1a7015a 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -193,13 +193,24 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
   scf::ForOp result;
-  LogicalResult status =
-      scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
-  if (failed(status)) {
-    DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                       << "failed to peel";
-    return diag;
+  if (getPeelFront()) {
+    LogicalResult status =
+        scf::peelFirstIterationForLoop(rewriter, target, result);
+    if (failed(status)) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "failed to peel the first iteration";
+      return diag;
+    }
+  } else {
+    LogicalResult status =
+        scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
+    if (failed(status)) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "failed to peel the last iteration";
+      return diag;
+    }
   }
+
   results.push_back(target);
   results.push_back(result);
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 93b81794ec265..95e4f2e5b560a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -191,6 +191,23 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
   });
 }
 
+static void removeAffineOpInsideFirstIteration(RewriterBase &rewriter,
+                                               ForOp partialIteration,
+                                               Value previousUb) {
+  Value partialIv = partialIteration.getInductionVar();
+  Value step = partialIteration.getStep();
+
+  partialIteration.walk([&](Operation *affineOp) {
+    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+      return WalkResult::advance();
+
+    // Hack to reuse the existing utils when ub - iv >= step.
+    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
+                                     step, /*insideLoop=*/true);
+    return WalkResult::advance();
+  });
+}
+
 LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
                                                       ForOp forOp,
                                                       ForOp &partialIteration) {
@@ -205,32 +222,84 @@ LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
   return success();
 }
 
+LogicalResult mlir::scf::peelFirstIterationForLoop(RewriterBase &b, ForOp forOp,
+                                                   ForOp &firstIteration) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+
+  // Peeling is not needed if there is one or less iteration.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) / *stepInt <= 1)
+    return success();
+
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(b.getContext(), sym0, sym1, sym2);
+  SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
+                              forOp.getStep()};
+
+  // New lower bound for main loop: %lb + %step
+  auto ubMap = AffineMap::get(0, 2, {sym0 + sym1});
+  b.setInsertionPoint(forOp);
+  auto loc = forOp.getLoc();
+  Value splitBound = b.createOrFold<AffineApplyOp>(
+      loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
+
+  // Peel the first iteration.
+  b.setInsertionPoint(forOp);
+  firstIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
+  firstIteration.getUpperBoundMutable().assign(splitBound);
+
+  // Update main loop with new lower bound.
+  forOp.getInitArgsMutable().assign(firstIteration->getResults());
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getLowerBoundMutable().assign(splitBound); });
+
+  return success();
+}
+
 static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
 static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
 
 namespace {
 struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
-  ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
-      : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
+  ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
+      : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
+        skipPartial(skipPartial) {}
 
   LogicalResult matchAndRewrite(ForOp forOp,
                                 PatternRewriter &rewriter) const override {
     // Do not peel already peeled loops.
     if (forOp->hasAttr(kPeeledLoopLabel))
       return failure();
-    if (skipPartial) {
-      // No peeling of loops inside the partial iteration of another peeled
-      // loop.
-      Operation *op = forOp.getOperation();
-      while ((op = op->getParentOfType<scf::ForOp>())) {
-        if (op->hasAttr(kPartialIterationLabel))
-          return failure();
+
+    scf::ForOp partialIteration;
+    // The case for peeling the first iteration of the loop.
+    if (peelFront) {
+      if (failed(
+              peelFirstIterationForLoop(rewriter, forOp, partialIteration))) {
+        return failure();
+      }
+      // Remove affine.min op in the first (partial) iteration.
+      removeAffineOpInsideFirstIteration(rewriter, partialIteration,
+                                         forOp.getUpperBound());
+
+    } else {
+      if (skipPartial) {
+        // No peeling of loops inside the partial iteration of another peeled
+        // loop.
+        Operation *op = forOp.getOperation();
+        while ((op = op->getParentOfType<scf::ForOp>())) {
+          if (op->hasAttr(kPartialIterationLabel))
+            return failure();
+        }
       }
+      // Apply loop peeling.
+      if (failed(
+              peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
+        return failure();
     }
-    // Apply loop peeling.
-    scf::ForOp partialIteration;
-    if (failed(peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
-      return failure();
+
     // Apply label, so that the same loop is not rewritten a second time.
     rewriter.updateRootInPlace(partialIteration, [&]() {
       partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
@@ -242,6 +311,8 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
     return success();
   }
 
+  bool peelFront;
+
   /// If set to true, loops inside partial iterations of another peeled loop
   /// are not peeled. This reduces the size of the generated code. Partial
   /// iterations are not usually performance critical.
@@ -273,7 +344,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
     auto *parentOp = getOperation();
     MLIRContext *ctx = parentOp->getContext();
     RewritePatternSet patterns(ctx);
-    patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
+    patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
     (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
 
     // Drop the markers.
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
new file mode 100644
index 0000000000000..f886b3522aef8
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt %s -scf-for-loop-peeling=peel-front=true -split-input-file | FileCheck %s
+
+//  CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
+//      CHECK: func @fully_static_bounds(
+//  CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//  CHECK-DAG:   %[[C4_I32:.*]] = arith.constant 4 : i32
+//  CHECK-DAG:   %[[C0_I32:.*]] = arith.constant 0 : i32
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C17:.*]] = arith.constant 17 : index
+//      CHECK:   %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]]
+// CHECK-SAME:       step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+//      CHECK:     %[[INIT:.*]] = arith.addi %[[ACC]], %[[C4_I32]] : i32
+//      CHECK:     scf.yield %[[INIT]]
+//      CHECK:   }
+//      CHECK:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C4]] to %[[C17]]
+// CHECK-SAME:       step %[[C4]] iter_args(%[[ACC:.*]] = %[[FIRST]]) -> (i32) {
+//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+//      CHECK:     scf.yield %[[ADD]]
+//      CHECK:   }
+//      CHECK:   return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @fully_static_bounds() -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %lb = arith.constant 0 : index
+  %step = arith.constant 4 : index
+  %ub = arith.constant 17 : index
+  %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
+    %s = affine.min #map(%ub, %iv)[%step]
+    %casted = arith.index_cast %s : index to i32
+    %0 = arith.addi %arg, %casted : i32
+    scf.yield %0 : i32
+  }
+  return %r : i32
+}
+
+// -----
+
+//  CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
+//      CHECK: func @no_loop_results(
+// CHECK-SAME:     %[[UB:.*]]: index, %[[MEMREF:.*]]: memref<i32>
+//  CHECK-DAG:   %[[C4_I32:.*]] = arith.constant 4 : i32
+//  CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C4]] {
+//      CHECK:     %[[LOAD:.*]] = memref.load %[[MEMREF]][]
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[LOAD]], %[[C4_I32]] : i32
+//      CHECK:     memref.store %[[ADD]], %[[MEMREF]]
+//      CHECK:   }
+//      CHECK:   scf.for %[[IV2:.*]] = %[[C4]] to %[[UB]] step %[[C4]] {
+//      CHECK:     %[[REM:.*]] = affine.min #[[MAP]](%[[UB]], %[[IV2]])[%[[C4]]]
+//      CHECK:     %[[LOAD2:.*]] = memref.load %[[MEMREF]][]
+//      CHECK:     %[[CAST2:.*]] = arith.index_cast %[[REM]]
+//      CHECK:     %[[ADD2:.*]] = arith.addi %[[LOAD2]], %[[CAST2]]
+//      CHECK:     memref.store %[[ADD2]], %[[MEMREF]]
+//      CHECK:   }
+//      CHECK:   return
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @no_loop_results(%ub : index, %d : memref<i32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %lb = arith.constant 0 : index
+  %step = arith.constant 4 : index
+  scf.for %iv = %lb to %ub step %step {
+    %s = affine.min #map(%ub, %iv)[%step]
+    %r = memref.load %d[] : memref<i32>
+    %casted = arith.index_cast %s : index to i32
+    %0 = arith.addi %r, %casted : i32
+    memref.store %0, %d[] : memref<i32>
+  }
+  return
+}
+
+// -----
+
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+//      CHECK: func @fully_dynamic_bounds(
+// CHECK-SAME:     %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
+//      CHECK:   %[[C0_I32:.*]] = arith.constant 0 : i32
+//      CHECK:   %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[STEP]]]
+//      CHECK:   %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
+// CHECK-SAME:       step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[STEP]] : index to i32
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+//      CHECK:     scf.yield %[[ADD]]
+//      CHECK:   }
+//      CHECK:   %[[RESULT:.*]] = scf.for %[[IV2:.*]] = %[[NEW_UB]] to %[[UB]]
+// CHECK-SAME:       step %[[STEP]] iter_args(%[[ACC2:.*]] = %[[FIRST]]) -> (i32) {
+//      CHECK:     %[[REM:.*]] = affine.min #[[MAP1]](%[[UB]], %[[IV2]])[%[[STEP]]]
+//      CHECK:     %[[CAST2:.*]] = arith.index_cast %[[REM]]
+//      CHECK:     %[[ADD2:.*]] = arith.addi %[[ACC2]], %[[CAST2]]
+//      CHECK:     scf.yield %[[ADD2]]
+//      CHECK:   }
+//      CHECK:   return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0) -> i32 {
+    %s = affine.min #map(%ub, %iv)[%step]
+    %casted = arith.index_cast %s : index to i32
+    %0 = arith.addi %arg, %casted : i32
+    scf.yield %0 : i32
+  }
+  return %r : i32
+}
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index 74601cf5b34a1..93ebf67f8b713 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -77,6 +77,36 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @loop_peel_first_iter_op
+func.func @loop_peel_first_iter_op() {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[C41:.+]] = arith.constant 41
+  // CHECK: %[[C5:.+]] = arith.constant 5
+  // CHECK: %[[C5_0:.+]] = arith.constant 5
+  // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C5_0]] step %[[C5]]
+  // CHECK:   arith.addi
+  // CHECK: scf.for %{{.+}} = %[[C5_0]] to %[[C41]] step %[[C5]]
+  // CHECK:   arith.addi
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 41 : index
+  %2 = arith.constant 5 : index
+  scf.for %i = %0 to %1 step %2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    %main_loop, %remainder = transform.loop.peel %1 {peel_front = true} : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index

>From 6069966eb3665d9a6666593d0e02696bc17baef8 Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Mon, 11 Dec 2023 20:18:16 -0800
Subject: [PATCH 2/2] Rename to peelForLoopFirstIteration and return failure when no peeling is needed

---
 .../SCF/TransformOps/SCFTransformOps.td       |  4 +-
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  2 +-
 .../SCF/TransformOps/SCFTransformOps.cpp      |  2 +-
 .../SCF/Transforms/LoopSpecialization.cpp     | 56 +++++++------------
 .../Dialect/SCF/for-loop-peeling-front.mlir   | 49 +++++++++++++---
 5 files changed, 66 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index a9c70c9e3433f..b5ac22a2a758d 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -126,8 +126,8 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
      contains the remaining iteration, and the new lower bound is the original
      lower bound plus the number of steps.
 
-     For the other case, this operation returns two scf::ForOp Ops, with the first
-     Op satisfying the postcondition: "the loop trip count is divisible by the step".
+     When `peel_front` is false, this operation returns two scf::ForOp Ops, with the first
+     scf::ForOp satisfying the postcondition: "the loop trip count is divisible by the step".
      The second loop Op contains the remaining iteration. Note that even though the
      Payload IR modification may be performed in-place, this operation consumes
      the operand handle and produces a new one.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index d736d049fd093..78be1c77bfab8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -83,7 +83,7 @@ LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
 
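-/// Peel the first iteration out of the scf.for loop. If there is only one
-/// iteration, return the original loop.
-LogicalResult peelFirstIterationForLoop(RewriterBase &rewriter, ForOp forOp,
+/// Peel the first iteration out of the scf.for loop. Return failure, leaving
+/// the loop untouched, if its bounds are constant and it runs at most once.
+LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp,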
                                         scf::ForOp &partialIteration);
 
 /// Tile a parallel loop of the form
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 1fdd6f1a7015a..bc2fe5772af9d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -195,7 +195,7 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
   scf::ForOp result;
   if (getPeelFront()) {
     LogicalResult status =
-        scf::peelFirstIterationForLoop(rewriter, target, result);
+        scf::peelForLoopFirstIteration(rewriter, target, result);
     if (failed(status)) {
       DiagnosedSilenceableFailure diag =
           emitSilenceableError() << "failed to peel the first iteration";
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 95e4f2e5b560a..9fda4861d40a3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -191,23 +191,6 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
   });
 }
 
-static void removeAffineOpInsideFirstIteration(RewriterBase &rewriter,
-                                               ForOp partialIteration,
-                                               Value previousUb) {
-  Value partialIv = partialIteration.getInductionVar();
-  Value step = partialIteration.getStep();
-
-  partialIteration.walk([&](Operation *affineOp) {
-    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
-      return WalkResult::advance();
-
-    // Hack to reuse the existing utils when ub - iv >= step.
-    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
-                                     step, /*insideLoop=*/true);
-    return WalkResult::advance();
-  });
-}
-
 LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
                                                       ForOp forOp,
                                                       ForOp &partialIteration) {
@@ -222,7 +205,13 @@ LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
   return success();
 }
 
-LogicalResult mlir::scf::peelFirstIterationForLoop(RewriterBase &b, ForOp forOp,
+/// Peels the first iteration off `forOp`. A clone with upper bound `lb + step`
+/// is created before `forOp` and returned in `firstIteration`; `forOp` then
+/// starts at `lb + step` and takes its iter_args from the clone's results.
+/// Every use of the upper-bound value inside the clone (body included) is
+/// remapped to `lb + step`, and that bound is not clamped to `ub`. Returns
+/// failure if all bounds are constant and the loop has at most one iteration.
+LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
                                                    ForOp &firstIteration) {
   RewriterBase::InsertionGuard guard(b);
   auto lbInt = getConstantIntValue(forOp.getLowerBound());
@@ -231,29 +220,28 @@ LogicalResult mlir::scf::peelFirstIterationForLoop(RewriterBase &b, ForOp forOp,

-  // Peeling is not needed if there is one or less iteration.
-  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) / *stepInt <= 1)
-    return success();
+  // Peeling is not needed if the loop runs at most once (ub - lb <= step).
+  if (lbInt && ubInt && stepInt && *ubInt - *lbInt <= *stepInt)
+    return failure();

-  AffineExpr sym0, sym1, sym2;
-  bindSymbols(b.getContext(), sym0, sym1, sym2);
-  SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
-                              forOp.getStep()};
+  AffineExpr lbSymbol, stepSymbol;
+  bindSymbols(b.getContext(), lbSymbol, stepSymbol);

   // New lower bound for main loop: %lb + %step
-  auto ubMap = AffineMap::get(0, 2, {sym0 + sym1});
+  auto ubMap = AffineMap::get(0, 2, {lbSymbol + stepSymbol});
   b.setInsertionPoint(forOp);
   auto loc = forOp.getLoc();
   Value splitBound = b.createOrFold<AffineApplyOp>(
       loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});

   // Peel the first iteration.
-  b.setInsertionPoint(forOp);
-  firstIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
-  firstIteration.getUpperBoundMutable().assign(splitBound);
+  IRMapping map;
+  map.map(forOp.getUpperBound(), splitBound);
+  firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));

   // Update main loop with new lower bound.
-  forOp.getInitArgsMutable().assign(firstIteration->getResults());
-  b.updateRootInPlace(
-      forOp, [&]() { forOp.getLowerBoundMutable().assign(splitBound); });
+  b.updateRootInPlace(forOp, [&]() {
+    forOp.getInitArgsMutable().assign(firstIteration->getResults());
+    forOp.getLowerBoundMutable().assign(splitBound);
+  });

   return success();
 }
@@ -277,13 +265,9 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
     // The case for peeling the first iteration of the loop.
     if (peelFront) {
       if (failed(
-              peelFirstIterationForLoop(rewriter, forOp, partialIteration))) {
+              peelForLoopFirstIteration(rewriter, forOp, partialIteration))) {
         return failure();
       }
-      // Remove affine.min op in the first (partial) iteration.
-      removeAffineOpInsideFirstIteration(rewriter, partialIteration,
-                                         forOp.getUpperBound());
-
     } else {
       if (skipPartial) {
         // No peeling of loops inside the partial iteration of another peeled
@@ -311,6 +295,8 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
     return success();
   }
 
+  // If set to true, the first iteration of the loop will be peeled. Otherwise,
+  // the unevenly divisible loop will be peeled at the end.
   bool peelFront;
 
   /// If set to true, loops inside partial iterations of another peeled loop
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
index f886b3522aef8..65141ff7623ff 100644
--- a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir
@@ -3,20 +3,21 @@
 //  CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
 //      CHECK: func @fully_static_bounds(
 //  CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
-//  CHECK-DAG:   %[[C4_I32:.*]] = arith.constant 4 : i32
 //  CHECK-DAG:   %[[C0_I32:.*]] = arith.constant 0 : i32
 //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C17:.*]] = arith.constant 17 : index
 //      CHECK:   %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]]
 // CHECK-SAME:       step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
-//      CHECK:     %[[INIT:.*]] = arith.addi %[[ACC]], %[[C4_I32]] : i32
+//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]]
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
+//      CHECK:     %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
 //      CHECK:     scf.yield %[[INIT]]
 //      CHECK:   }
 //      CHECK:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C4]] to %[[C17]]
 // CHECK-SAME:       step %[[C4]] iter_args(%[[ACC:.*]] = %[[FIRST]]) -> (i32) {
-//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
-//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
-//      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+//      CHECK:     %[[MIN2:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]]
+//      CHECK:     %[[CAST2:.*]] = arith.index_cast %[[MIN2]] : index to i32
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST2]] : i32
 //      CHECK:     scf.yield %[[ADD]]
 //      CHECK:   }
 //      CHECK:   return %[[RESULT]]
@@ -40,12 +41,13 @@ func.func @fully_static_bounds() -> i32 {
 //  CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
 //      CHECK: func @no_loop_results(
 // CHECK-SAME:     %[[UB:.*]]: index, %[[MEMREF:.*]]: memref<i32>
-//  CHECK-DAG:   %[[C4_I32:.*]] = arith.constant 4 : i32
 //  CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
 //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //      CHECK:   scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C4]] {
+//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]]
 //      CHECK:     %[[LOAD:.*]] = memref.load %[[MEMREF]][]
-//      CHECK:     %[[ADD:.*]] = arith.addi %[[LOAD]], %[[C4_I32]] : i32
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]]
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[LOAD]], %[[CAST]] : i32
 //      CHECK:     memref.store %[[ADD]], %[[MEMREF]]
 //      CHECK:   }
 //      CHECK:   scf.for %[[IV2:.*]] = %[[C4]] to %[[UB]] step %[[C4]] {
@@ -81,7 +83,8 @@ func.func @no_loop_results(%ub : index, %d : memref<i32>) {
 //      CHECK:   %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[STEP]]]
 //      CHECK:   %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
 // CHECK-SAME:       step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
-//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[STEP]] : index to i32
+//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP1]](%[[NEW_UB]], %[[IV]])[%[[STEP]]]
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
 //      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
 //      CHECK:     scf.yield %[[ADD]]
 //      CHECK:   }
@@ -104,3 +107,33 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
   }
   return %r : i32
 }
+
+// -----
+
+//  CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)>
+//      CHECK: func @no_peeling_front(
+//  CHECK-DAG:   %[[C0_I32:.*]] = arith.constant 0 : i32
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//      CHECK:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]]
+// CHECK-SAME:       step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+//      CHECK:     %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]]
+//      CHECK:     %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32
+//      CHECK:     %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
+//      CHECK:     scf.yield %[[ADD]]
+//      CHECK:   }
+//      CHECK:   return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @no_peeling_front() -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %lb = arith.constant 0 : index
+  %step = arith.constant 4 : index
+  %ub = arith.constant 4 : index
+  %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 {
+    %s = affine.min #map(%ub, %iv)[%step]
+    %casted = arith.index_cast %s : index to i32
+    %0 = arith.addi %arg, %casted : i32
+    scf.yield %0 : i32
+  }
+  return %r : i32
+}



More information about the Mlir-commits mailing list