[Lldb-commits] [lld] [clang-tools-extra] [llvm] [libcxx] [compiler-rt] [mlir] [libc] [clang] [lldb] [MLIR][LLVM] Add Continuous Loop Peeling transform to SCF (PR #71555)

via lldb-commits lldb-commits at lists.llvm.org
Thu Dec 14 03:21:35 PST 2023

https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/71555

>From 7bb2f9793b2a2cccbaa401f6e2ac850b587f2b59 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 1/9] [MLIR][LLVM] Add Continuous Loop Peeling transform to SCF

This patch adds continuous loop peeling to scf loop transforms
in the MLIR backend. This transforms the target loop into a
chain of loops, with step sizes that are powers of two and
decrease exponentially across subsequent loops. Originally
authored by Litu Zhou litu.zhou at huawei.com.
 .../SCF/TransformOps/SCFTransformOps.td       |  36 +++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 147 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     |  98 ++++++++++++
 3 files changed, 281 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/loop-continuous-peel.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 14df7e23a430fb..e3d79a7f0ae40f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -147,6 +147,42 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
+def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Transforms the loop into a chain of loops, with step sizes that are
+    powers of two and decrease exponentially across subsequent loops.
+    The transform is similar to loop.peel in that it creates a loop
+    with a step (that is a power of 2) to divide the range evenly, with the
+    difference that the remaining iterations are spread across similar loops
+    with exponentially decreasing step sizes, with the last loop having a
+    step size of 2^0 = 1.
+    #### Return modes
+    This operation consumes the `target` handles and produces the
+    continuously-peeled loop.
+  }];
+  let arguments =
+      (ins TransformHandleTypeInterface:$target,
+           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+  // TODO: Return both the peeled loop and the remainder loop.
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type(operands, results)";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
 def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 62370604142cd5..dcba6a8b406b21 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -206,6 +206,153 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
+// LoopContinuousPeelOp
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value &splitBound) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  // No specialization necessary if step already divides upper bound evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No specialization necessary if step size is 1.
+  if (stepInt == static_cast<int64_t>(1))
+    return failure();
+  // Create ForOp for partial iteration.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  forOp.replaceAllUsesWith(partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+  // Set new upper loop bound.
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+  return success();
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
+    mapping.map(arg, operand.get());
+  }
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // then branch
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
+  // else branch
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty()) {
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  }
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  scf::ForOp loop, currentLoop, partialLoop;
+  loop = dyn_cast<scf::ForOp>(target);
+  auto lbInt = getConstantIntValue(loop.getLowerBound());
+  auto stepInt = getConstantIntValue(loop.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return DiagnosedSilenceableFailure::
+        definiteFailure(); // step size must be a known positive constant
+  Value initialUb = loop.getUpperBound();
+  Value initialStep = loop.getStep();
+  uint64_t loopStep = *stepInt;
+  currentLoop = loop;
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    rewriter.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    currentLoop.getStepMutable().assign(constStepOp);
+    rewriter.setInsertionPoint(currentLoop);
+    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
+    // Canonicalize min/max affine operations
+    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
+    // they are then replaced by the current step size.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      rewriter.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(rewriter.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
+          initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        rewriter.replaceOp(affineOp, currentLoop.getStep());
+      else
+        rewriter.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+      rewriter.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(rewriter.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
+          initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        rewriter.replaceOp(affineOp, currentLoop.getStep());
+      else
+        rewriter.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    // Prepare for the next iteration
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialLoop;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+  assert(loops.size() > 0 && "There should be at least one loop available");
+  if (getSingleIterOpt()) {
+    for (size_t i = 1; i < loops.size(); ++i) {
+      convertSingleIterFor(rewriter, loops[i]);
+    }
+  }
+  results.push_back(loops.front());
+  return DiagnosedSilenceableFailure::success();
 // LoopPipelineOp
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
new file mode 100644
index 00000000000000..752e1b1efed92a
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+#map = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+module {
+  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.mulf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<?xf32>
+    return %0 : tensor<?xf32>
+  }
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
+    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
+    transform.yield
+  }
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
+// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %{{.*}} = arith.constant 8 : index
+// CHECK:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
+// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
+// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
+// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
+// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
+// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
+// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
+// CHECK:           linalg.yield  %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
+// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
+// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
+// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
+// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
+// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
+// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       return %[[INS4]] : tensor<?xf32>

>From 98ba0bd1177af08f286a0eaf16a85828f5410c7d Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 2/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

1. The transformation has been moved to
   mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

2. Test in mlir/test/Dialect/SCF/loop-continuous-peel.mlir
   simplified using scf.for

3. Added -scf-for-loop-continuous-peeling pass for applying
   this transform independently on scf.for loops.

This commit should be squashed into the original.
 .../SCF/TransformOps/SCFTransformOps.td       |   5 +-
 .../mlir/Dialect/SCF/Transforms/Passes.h      |   6 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  11 +
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  41 ++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 145 ++----------
 .../SCF/Transforms/LoopSpecialization.cpp     | 221 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     | 134 ++++-------
 7 files changed, 336 insertions(+), 227 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index e3d79a7f0ae40f..10602271a9aa52 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -167,9 +167,10 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
   let arguments =
       (ins TransformHandleTypeInterface:$target,
-           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+           DefaultValuedAttr<BoolAttr, "false">:$convert_single_iter_loops_to_if);
   // TODO: Return both the peeled loop and the remainder loop.
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs TransformHandleTypeInterface:$peeled_loop,
+                      TransformHandleTypeInterface:$remainder_loop);
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..9ff1d6a07f17c3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -31,6 +31,12 @@ std::unique_ptr<Pass> createForLoopSpecializationPass();
 /// better vectorization.
 std::unique_ptr<Pass> createForLoopPeelingPass();
+/// Creates a pass that transforms a for loop into a chain of loops
+/// where the step size is always a power of 2 but decreases exponentially
+/// across the loops. Helps with dividing the iteration space across all
+/// resulting peeled loops evenly.
+std::unique_ptr<Pass> createForLoopContinuousPeelingPass();
 /// Creates a pass that canonicalizes affine.min and affine.max operations
 /// inside of scf.for loops with known lower and upper bounds.
 std::unique_ptr<Pass> createSCFForLoopCanonicalizationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index bbc673f44977ac..daafb78a9134cc 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -40,6 +40,17 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
   let dependentDialects = ["affine::AffineDialect"];
+def SCFForLoopContinuousPeeling : Pass<"scf-for-loop-continuous-peeling"> {
+  let summary = "Convert a loop into a chain of loops with exponentially decreasing steps that are power of 2.";
+  let constructor = "mlir::createForLoopContinuousPeelingPass()";
+  let options = [
+    Option<"convertSingleIterLoopsToIf", "convert-single-iter-loops-to-if", "bool",
+           /*default=*/"false",
+           "Convert single iteration loops to if. ">
+  ];
+  let dependentDialects = ["affine::AffineDialect"];
 def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";
   let constructor = "mlir::createForLoopSpecializationPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 347beb9e4c64f8..8a69b207677731 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -81,6 +81,47 @@ void naivelyFuseParallelOps(Region &region);
 LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
                                            scf::ForOp &partialIteration);
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops.
+/// E.g., assuming a lower bound of 0, the following loop
+/// ```
+/// scf.for %iv = %c0 to %ub step %c8 {
+///   (loop body)
+/// }
+/// ```
+/// is rewritten into the following pseudo IR:
+/// ```
+/// %newUb = %ub - (%ub mod %c8)
+/// scf.for %iv = %c0 to %newUb step %c8 {
+///   (loop body)
+/// }
+/// %newUb2 = %ub - (%ub mod %c4)
+/// scf.for %iv2 = %newUb to %newUb2 step %c4 {
+///   (loop body)
+/// }
+/// %newUb3 = %ub - (%ub mod %c2)
+/// scf.for %iv3 = %newUb2 to %newUb3 step %c2 {
+///   (loop body)
+/// }
+/// scf.for %iv4 = %newUb3 to %ub step %c1 {
+///   (loop body)
+/// }
+/// ```
+/// Similar to loop peeling, this function simplifies the affine.min and
+/// affine.max ops in the body of each resulting for loop for better
+/// canonicalization opportunities.
+/// The return value indicates if the loop was rewritten. The loop
+/// is not rewritten if the step size is 1 or dynamic.
+continuousPeelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
+                                       scf::ForOp &partialIteration,
+                                       bool convertSingleIterLoopsToIf);
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index dcba6a8b406b21..45b91f0d9b1fcf 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -210,146 +210,27 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
 // LoopContinuousPeelOp
-static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
-                                     scf::ForOp &partialIteration,
-                                     Value &splitBound) {
-  RewriterBase::InsertionGuard guard(b);
-  auto lbInt = getConstantIntValue(forOp.getLowerBound());
-  auto ubInt = getConstantIntValue(forOp.getUpperBound());
-  auto stepInt = getConstantIntValue(forOp.getStep());
-  // No specialization necessary if step already divides upper bound evenly.
-  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
-    return failure();
-  // No specialization necessary if step size is 1.
-  if (stepInt == static_cast<int64_t>(1))
-    return failure();
-  // Create ForOp for partial iteration.
-  b.setInsertionPointAfter(forOp);
-  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
-  partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
-  partialIteration.getInitArgsMutable().assign(forOp->getResults());
-  // Set new upper loop bound.
-  b.updateRootInPlace(
-      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
-  return success();
-static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
-  Location loc = forOp->getLoc();
-  IRMapping mapping;
-  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
-  for (auto [arg, operand] :
-       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
-    mapping.map(arg, operand.get());
-  }
-  b.setInsertionPoint(forOp);
-  auto cond =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                              forOp.getLowerBound(), forOp.getUpperBound());
-  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
-  // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
-  // else branch
-  b.setInsertionPointToStart(ifOp.elseBlock());
-  if (!forOp->getResultTypes().empty()) {
-    b.create<scf::YieldOp>(loc, forOp.getInits());
-  }
-  b.replaceOp(forOp, ifOp->getResults());
-  return ifOp;
 DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  scf::ForOp loop, currentLoop, partialLoop;
+  scf::ForOp loop, result;
   loop = dyn_cast<scf::ForOp>(target);
-  auto lbInt = getConstantIntValue(loop.getLowerBound());
-  auto stepInt = getConstantIntValue(loop.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return DiagnosedSilenceableFailure::
-        definiteFailure(); // step size must be a known positive constant
-  Value initialUb = loop.getUpperBound();
-  Value initialStep = loop.getStep();
-  uint64_t loopStep = *stepInt;
-  currentLoop = loop;
-  AffineExpr sym0, sym1, sym2;
-  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
-  AffineMap defaultSplitMap =
-      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
-  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
-  bool usePowerSplit = (lbInt.has_value()) &&
-                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
-                       (loopStep == llvm::bit_floor(loopStep));
-  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
-  SmallVector<scf::ForOp> loops;
-  while (loopStep) {
-    rewriter.setInsertionPoint(currentLoop);
-    auto constStepOp =
-        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
-    currentLoop.getStepMutable().assign(constStepOp);
-    rewriter.setInsertionPoint(currentLoop);
-    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
-        currentLoop.getLoc(), splitMap,
-        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
-                   currentLoop.getStep()});
-    LogicalResult status =
-        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
-    // Canonicalize min/max affine operations
-    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
-    // they are then replaced by the current step size.
-    // TODO: Alternative method - update affine map to reflect the loop step
-    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });
+  bool convertSingleIterLoopsToIf = false;
-    // Prepare for the next iteration
-    loops.push_back(currentLoop);
-    if (failed(status))
-      break;
-    currentLoop = partialLoop;
-    uint64_t maxPower = llvm::bit_floor(loopStep);
-    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
-  }
-  assert(loops.size() > 0 && "There should be at least one loop available");
-  if (getSingleIterOpt()) {
-    for (size_t i = 1; i < loops.size(); ++i) {
-      convertSingleIterFor(rewriter, loops[i]);
-    }
+  if (getConvertSingleIterLoopsToIf())
+    convertSingleIterLoopsToIf = true;
+  LogicalResult status = scf::continuousPeelForLoopAndSimplifyBounds(
+      rewriter, loop, result, convertSingleIterLoopsToIf);
+  if (failed(status)) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to perform continuous peeling";
+    return diag;
-  results.push_back(loops.front());
+  results.push_back(loop);
+  results.push_back(result);
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f208e5245977d8..e2bc0e410878d4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -28,6 +28,7 @@
 namespace mlir {
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -105,6 +106,165 @@ static void specializeForLoopForUnrolling(ForOp op) {
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value &splitBound) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  // No split necessary if the step already divides (ub - lb) evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No specialization necessary if step size is 1.
+  if (stepInt == static_cast<int64_t>(1))
+    return failure();
+  // Create ForOp for partial iteration.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  forOp.replaceAllUsesWith(partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+  // Set new upper loop bound.
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+  return success();
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
+    mapping.map(arg, operand.get());
+  }
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // then branch
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
+  // else branch
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty()) {
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  }
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops. Helps
+/// divide the iteration space across all resulting peeled loops evenly.
+/// Optionally, convert all single iteration for loops to if-else
+/// blocks when convert_single_iter_loops_to_if attribute is set to true or
+/// alternatively with the convert-single-iter-loops-to-if option for the
+/// scf-for-loop-continuous-peeling pass.
+static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
+                                           ForOp &partialIteration,
+                                           bool convertSingleIterLoopsToIf) {
+  scf::ForOp currentLoop;
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return failure(); // step size must be a known positive constant
+  Value initialUb = forOp.getUpperBound();
+  Value initialStep = forOp.getStep();
+  uint64_t loopStep = *stepInt;
+  currentLoop = forOp;
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(b.getContext(), sym0, sym1, sym2);
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    b.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        b.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    currentLoop.getStepMutable().assign(constStepOp);
+    b.setInsertionPoint(currentLoop);
+    Value splitBound = b.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(b, currentLoop, partialIteration, splitBound);
+    // Canonicalize min/max affine operations
+    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
+    // they are then replaced by the current step size.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    // Prepare for the next iteration
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialIteration;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+  assert(loops.size() > 0 && "There should be at least one loop available");
+  if (convertSingleIterLoopsToIf) {
+    for (size_t i = 1; i < loops.size(); ++i) {
+      convertSingleIterFor(b, loops[i]);
+    }
+  }
+  return success();
+LogicalResult mlir::scf::continuousPeelForLoopAndSimplifyBounds(
+    RewriterBase &rewriter, ForOp forOp, ForOp &partialIteration,
+    bool convertSingleIterLoopsToIf) {
+  if (failed(continuousPeelForLoop(rewriter, forOp, partialIteration,
+                                   convertSingleIterLoopsToIf)))
+    return failure();
+  return success();
 /// Rewrite a for loop with bounds/step that potentially do not divide evenly
 /// into a for loop where the step divides the iteration space evenly, followed
 /// by an scf.if for the last (partial) iteration (if any).
@@ -241,6 +401,45 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
 } // namespace
+namespace {
+struct ForLoopContinuousPeelingPattern : public OpRewritePattern<ForOp> {
+  ForLoopContinuousPeelingPattern(MLIRContext *ctx,
+                                  bool convertSingleIterLoopsToIf)
+      : OpRewritePattern<ForOp>(ctx),
+        convertSingleIterLoopsToIf(convertSingleIterLoopsToIf) {}
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Do not peel already peeled loops.
+    if (forOp->hasAttr(kPeeledLoopLabel))
+      return failure();
+    // Apply continuous loop peeling.
+    scf::ForOp partialIteration;
+    if (failed(continuousPeelForLoopAndSimplifyBounds(
+            rewriter, forOp, partialIteration, convertSingleIterLoopsToIf)))
+      return failure();
+    rewriter.updateRootInPlace(partialIteration, [&]() {
+      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+    });
+    rewriter.updateRootInPlace(forOp, [&]() {
+      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    });
+    return success();
+  }
+  /// If set to true, the loops that follow the first loop in the chain
+  /// produced by continuous peeling are converted into scf.if operations.
+  /// Mirrors the convert-single-iter-loops-to-if pass option and the
+  /// convert_single_iter_loops_to_if transform op attribute; the value is
+  /// forwarded to continuousPeelForLoopAndSimplifyBounds.
+  bool convertSingleIterLoopsToIf;
+} // namespace
 namespace {
 struct ParallelLoopSpecialization
     : public impl::SCFParallelLoopSpecializationBase<
@@ -273,6 +472,24 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
+struct ForLoopContinuousPeeling
+    : public impl::SCFForLoopContinuousPeelingBase<ForLoopContinuousPeeling> {
+  void runOnOperation() override {
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopContinuousPeelingPattern>(ctx,
+                                                  convertSingleIterLoopsToIf);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
+    // Drop the markers.
+    parentOp->walk([](Operation *op) {
+      op->removeAttr(kPeeledLoopLabel);
+      op->removeAttr(kPartialIterationLabel);
+    });
+  }
 } // namespace
 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
@@ -286,3 +503,7 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
 std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
   return std::make_unique<ForLoopPeeling>();
+std::unique_ptr<Pass> mlir::createForLoopContinuousPeelingPass() {
+  return std::make_unique<ForLoopContinuousPeeling>();
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
index 752e1b1efed92a..e051e6a43be70e 100644
--- a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -1,98 +1,46 @@
-// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -scf-for-loop-continuous-peeling=convert-single-iter-loops-to-if=true -split-input-file | FileCheck %s
-#map = affine_map<(d0) -> ()>
-#map1 = affine_map<(d0) -> (d0)>
-module {
-  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %3 = arith.mulf %in, %out : f32
-      linalg.yield %3 : f32
-    } -> tensor<?xf32>
-    return %0 : tensor<?xf32>
-  }
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
-    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
-    transform.yield
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @foo(%ub: index) -> index {
+  %c0 = arith.constant 0 : index
+  %step = arith.constant 8 : index
+  %0 = scf.for %iv = %c0 to %ub step %step iter_args(%arg = %c0) -> (index) {
+    %1 = affine.min #map(%ub, %iv)[%step]
+    %2 = index.add %1, %arg
+    scf.yield %2 : index
+  return %0 : index
 // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
-// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
-// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
-// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %{{.*}} = arith.constant 8 : index
-// CHECK:       %[[C8:.*]] = arith.constant 8 : index
-// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
-// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
-// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C4:.*]] = arith.constant 4 : index
-// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
-// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
-// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
-// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
-// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
-// CHECK:           linalg.yield  %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
-// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
-// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
-// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
-// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
-// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
-// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       return %[[INS4]] : tensor<?xf32>
+// CHECK: func.func @foo(%[[UB:.*]]: index) -> index {
+// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
+// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
+// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
+// CHECK: %[[STEP1:.*]] = arith.constant 1 : index
+// CHECK: %[[LB:.*]] = arith.constant 0 : index
+// CHECK: %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
+// CHECK: %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[ALB]], %[[STEP8]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
+// CHECK: %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
+// CHECK: %[[I4:.*]] = scf.if %[[I3]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I1]], %[[STEP4]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I1]] : index
+// CHECK: %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
+// CHECK: %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
+// CHECK: %[[I7:.*]] = scf.if %[[I6]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I4]], %[[STEP2]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I4]] : index
+// CHECK: %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
+// CHECK: %[[I9:.*]] = scf.if %[[I8]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I7]], %[[STEP1]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I7]] : index
+// CHECK: return %[[I9]] : index

>From 3a7d8b0c73d06a4b07ab654d69693af03a7f8e16 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 3/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Add case for step size 1

This commit should be squashed into the original.
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index e2bc0e410878d4..1c32b483d0d192 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -178,8 +178,11 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   scf::ForOp currentLoop;
   auto lbInt = getConstantIntValue(forOp.getLowerBound());
   auto stepInt = getConstantIntValue(forOp.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return failure(); // step size must be a known positive constant
+  // Step size must be a known positive constant greater than 1.
+  if (stepInt && stepInt <= static_cast<int64_t>(1))
+    return failure();
   Value initialUb = forOp.getUpperBound();
   Value initialStep = forOp.getStep();
   uint64_t loopStep = *stepInt;

>From 57062458bce81a06180969874ec1f04f50cd8d1b Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 10 Nov 2023 20:42:20 +0800
Subject: [PATCH 4/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Removed redundant TODO comment.

This commit should be squashed into the original.
 mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 10602271a9aa52..85d2086b1f49ef 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -168,7 +168,7 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
   let arguments =
       (ins TransformHandleTypeInterface:$target,
            DefaultValuedAttr<BoolAttr, "false">:$convert_single_iter_loops_to_if);
-  // TODO: Return both the peeled loop and the remainder loop.
   let results = (outs TransformHandleTypeInterface:$peeled_loop,

>From 9f2f8e36d9dd352c7a63629408c9708055417320 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 10 Nov 2023 23:38:29 +0800
Subject: [PATCH 5/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Renamed test to conform to other similar tests.

This commit should be squashed into the original.
 ...op-continuous-peel.mlir => for-loop-continuous-peeling.mlir} | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
 rename mlir/test/Dialect/SCF/{loop-continuous-peel.mlir => for-loop-continuous-peeling.mlir} (98%)

diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
similarity index 98%
rename from mlir/test/Dialect/SCF/loop-continuous-peel.mlir
rename to mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
index e051e6a43be70e..37a16c6dd094b7 100644
--- a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
@@ -43,4 +43,4 @@ func.func @foo(%ub: index) -> index {
 // CHECK: scf.yield %[[SUM]] : index
 // CHECK: } else {
 // CHECK: scf.yield %[[I7]] : index
-// CHECK: return %[[I9]] : index
+// CHECK: return %[[I9]] : index
\ No newline at end of file

>From 4a6830333e2e50a1a2e9d42bd3fab3f6c41a9fa1 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 14 Nov 2023 23:44:35 +0800
Subject: [PATCH 6/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

1. Added comment for splitLoopHelper
2. Added comment for convertSingleIterFor.
3. Some minor changes as per review suggestions.

This commit should be squashed into the original.
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 1c32b483d0d192..76ca7856e11768 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -106,7 +106,11 @@ static void specializeForLoopForUnrolling(ForOp op) {
-static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+/// Create a new for loop for the remaining iterations (partialIteration)
+/// after a for loop has been peeled. This is followed by correcting the
+/// loop bounds for both loops given the index (splitBound) where the
+/// iteration space is to be split up.
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp forOp,
                                      scf::ForOp &partialIteration,
                                      Value &splitBound) {
   RewriterBase::InsertionGuard guard(b);
@@ -135,6 +139,7 @@ static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
   return success();
+/// Convert single-iteration for loop to if-else block.
 static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
   Location loc = forOp->getLoc();
   IRMapping mapping;
@@ -180,7 +185,7 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   auto stepInt = getConstantIntValue(forOp.getStep());
   // Step size must be a known positive constant greater than 1.
-  if (stepInt && stepInt <= static_cast<int64_t>(1))
+  if (!stepInt || stepInt <= static_cast<int64_t>(1))
     return failure();
   Value initialUb = forOp.getUpperBound();
@@ -249,6 +254,7 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   assert(loops.size() > 0 && "There should be at least one loop available");
   if (convertSingleIterLoopsToIf) {
+    // Loops following the first loop will always be single-iteration.
     for (size_t i = 1; i < loops.size(); ++i) {
       convertSingleIterFor(b, loops[i]);

>From 1cfe41741ba0eb92e24861ba8bfd239059554540 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 12 Dec 2023 23:38:03 +0800
Subject: [PATCH 7/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Use inlining function instead of manually cloning operations from
replaced for loop.

This commit should be squashed into the original.
 .../SCF/Transforms/LoopSpecialization.cpp     | 30 ++++++++-----------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 76ca7856e11768..7a7bffd0294123 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -154,10 +154,15 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
                               forOp.getLowerBound(), forOp.getUpperBound());
   auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
   // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
+  ifOp->moveBefore(forOp);
+  // Replace block arguments with lower bound (replacement for IV) and
+  // iter_args.
+  SmallVector<Value> bbArgReplacements;
+  bbArgReplacements.push_back(forOp.getLowerBound());
+  llvm::append_range(bbArgReplacements, forOp.getInitArgs());
+  b.inlineBlockBefore(forOp.getBody(), ifOp.thenBlock(),
+                      ifOp.thenBlock()->begin(), bbArgReplacements);
   // else branch
   if (!forOp->getResultTypes().empty()) {
@@ -220,20 +225,11 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+    currentLoop.walk([&](Operation *affineOp) {
+      if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+        return WalkResult::advance();
-      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
+      auto clonedOp = b.clone(*affineOp);
       LogicalResult result = scf::rewritePeeledMinMaxOp(
           b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,

>From 08790ff090d4a36338f7ae4d55fbff0a9d3feb17 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 14 Dec 2023 17:40:39 +0800
Subject: [PATCH 8/9] Revert "[MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF"

This reverts commit 1cfe41741ba0eb92e24861ba8bfd239059554540.
 .../SCF/Transforms/LoopSpecialization.cpp     | 30 +++++++++++--------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index bc5a6f58c1644f..cb567408f66234 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -154,15 +154,10 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
                               forOp.getLowerBound(), forOp.getUpperBound());
   auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
   // then branch
-  ifOp->moveBefore(forOp);
-  // Replace block arguments with lower bound (replacement for IV) and
-  // iter_args.
-  SmallVector<Value> bbArgReplacements;
-  bbArgReplacements.push_back(forOp.getLowerBound());
-  llvm::append_range(bbArgReplacements, forOp.getInitArgs());
-  b.inlineBlockBefore(forOp.getBody(), ifOp.thenBlock(),
-                      ifOp.thenBlock()->begin(), bbArgReplacements);
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
   // else branch
   if (!forOp->getResultTypes().empty()) {
@@ -225,11 +220,20 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](Operation *affineOp) {
-      if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
-        return WalkResult::advance();
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      auto clonedOp = b.clone(*affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
       LogicalResult result = scf::rewritePeeledMinMaxOp(
           b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,

>From f857cfbfb4ea277b7f7c5a33b6eb22dcae005855 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 14 Dec 2023 19:02:26 +0800
Subject: [PATCH 9/9] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Clean up the rewriting of the affine op in the partial iterations.
This required changing the return type of rewritePeeledMinMaxOp
from LogicalResult to FailureOr<AffineApplyOp> to pass on the
newly created affine op.

This commit should be squashed into the original.
 .../SCF/Utils/AffineCanonicalizationUtils.h   |  7 +++--
 .../SCF/Transforms/LoopSpecialization.cpp     | 31 ++++++-------------
 .../SCF/Utils/AffineCanonicalizationUtils.cpp | 10 +++---
 3 files changed, 19 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
index 5022cbf898ec17..8ef1d0a5ffe803 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
@@ -14,6 +14,7 @@
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -75,9 +76,9 @@ LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
 /// ```
 /// min/max operations inside the partial iteration are rewritten in a similar
 /// way.
-LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                    Value iv, Value ub, Value step,
-                                    bool insideLoop);
+rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, Value iv, Value ub,
+                      Value step, bool insideLoop);
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index cb567408f66234..a2baf243a01245 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -220,27 +220,16 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
+    currentLoop.walk([&](Operation *affineOp) {
+      if (isa<AffineMinOp, AffineMaxOp>(affineOp)) {
+        FailureOr<AffineApplyOp> result = scf::rewritePeeledMinMaxOp(
+            b, affineOp, currentLoop.getInductionVar(), initialUb, initialStep,
+            /*insideLoop=*/true);
+        // correct the step of the newly created affine op
+        if (!failed(result))
+          b.replaceOp(result.value(), currentLoop.getStep());
+      }
+      return WalkResult::advance();
     // Prepare for the next iteration
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 89dfb61d48af94..7c748ab5c9b8f1 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
@@ -189,16 +188,17 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
 /// * Inside the peeled loop: min(step, ub - iv) == step
 /// * Inside the partial iteration: min(step, ub - iv) == ub - iv
-/// Returns `success` if the given operation was replaced by a new operation;
+/// Returns the new Affine op if the operation was replaced by a new operation;
 /// `failure` otherwise.
 /// Note: `ub` is the previous upper bound of the loop (before peeling).
 /// `insideLoop` must be true for min/max ops inside the loop and false for
 /// affine.min ops inside the partial iteration. For an explanation of the other
 /// parameters, see comment of `canonicalizeMinMaxOpInLoop`.
-LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                         Value iv, Value ub, Value step,
-                                         bool insideLoop) {
+FailureOr<AffineApplyOp> scf::rewritePeeledMinMaxOp(RewriterBase &rewriter,
+                                                    Operation *op, Value iv,
+                                                    Value ub, Value step,
+                                                    bool insideLoop) {
   FlatAffineValueConstraints constraints;
   constraints.appendSymbolVar({ub, step});

More information about the lldb-commits mailing list