[mlir] [clang] [llvm] [flang] [MLIR][LLVM] Add Continuous Loop Peeling transform to SCF (PR #77328)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 9 02:50:44 PST 2024


https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/77328

>From 91418f78aad0c994a49e9305516bd6bbb3d69d65 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 01/13] [MLIR][LLVM] Add Continuous Loop Peeling transform to
 SCF

This patch adds continuous loop peeling to scf loop transforms
in the MLIR backend. This transforms the target loop into a
chain of loops, with step sizes that are powers of two and
decrease exponentially across subsequent loops. Originally
authored by Litu Zhou (litu.zhou at huawei.com).
---
 .../SCF/TransformOps/SCFTransformOps.td       |  36 +++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 147 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     |  98 ++++++++++++
 3 files changed, 281 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/loop-continuous-peel.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index b5ac22a2a758dd..4ef2e1b30ec202 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -156,6 +156,42 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
   }];
 }
 
+def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Transforms the loop into a chain of loops, with step sizes that are
+    powers of two and decrease exponetially across subsequent loops.
+    The transform is similar to loop.peel in the effect that it creates a loop
+    with a step (that is power of 2) to divide the range evenly, with the
+    difference that the remaining iterations are spread across similar loops
+    with exponentially decreasing step sizes, with the last loop with step size
+    of 2^0 = 1.
+
+    #### Return modes
+
+    This operation consumes the `target` handles and produces the
+    continuously-peeled loop.
+  }];
+
+  let arguments =
+      (ins TransformHandleTypeInterface:$target,
+           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+  // TODO: Return both the peeled loop and the remainder loop.
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index bc2fe5772af9d6..10806d9e6ba022 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -217,6 +217,153 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===---------------------------------------------------------------------===//
+// LoopContinuousPeelOp
+//===---------------------------------------------------------------------===//
+
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value &splitBound) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+
+  // No specialization necessary if step already divides upper bound evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No specialization necessary if step size is 1.
+  if (stepInt == static_cast<int64_t>(1))
+    return failure();
+
+  // Create ForOp for partial iteration.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  forOp.replaceAllUsesWith(partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+
+  // Set new upper loop bound.
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+
+  return success();
+}
+
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
+    mapping.map(arg, operand.get());
+  }
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // then branch
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
+  // else branch
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty()) {
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  }
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+}
+
+DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  scf::ForOp loop, currentLoop, partialLoop;
+  loop = dyn_cast<scf::ForOp>(target);
+  auto lbInt = getConstantIntValue(loop.getLowerBound());
+  auto stepInt = getConstantIntValue(loop.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return DiagnosedSilenceableFailure::
+        definiteFailure(); // step size must be a known positive constant
+  Value initialUb = loop.getUpperBound();
+  Value initialStep = loop.getStep();
+  uint64_t loopStep = *stepInt;
+  currentLoop = loop;
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    rewriter.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    currentLoop.getStepMutable().assign(constStepOp);
+    rewriter.setInsertionPoint(currentLoop);
+    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
+
+    // Canonicalize min/max affine operations
+    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
+    // they are then replaced by the current step size.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      rewriter.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(rewriter.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
+          initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        rewriter.replaceOp(affineOp, currentLoop.getStep());
+      else
+        rewriter.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+      rewriter.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(rewriter.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
+          initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        rewriter.replaceOp(affineOp, currentLoop.getStep());
+      else
+        rewriter.eraseOp(clonedOp); // to avoid infinite walk
+    });
+
+    // Prepare for the next iteration
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialLoop;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+  assert(loops.size() > 0 && "There should be at least one loop available");
+  if (getSingleIterOpt()) {
+    for (size_t i = 1; i < loops.size(); ++i) {
+      convertSingleIterFor(rewriter, loops[i]);
+    }
+  }
+
+  results.push_back(loops.front());
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopPipelineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
new file mode 100644
index 00000000000000..752e1b1efed92a
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+module {
+  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.mulf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<?xf32>
+    return %0 : tensor<?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
+    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
+// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %{{.*}} = arith.constant 8 : index
+// CHECK:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
+// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
+// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
+// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
+// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
+// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
+// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
+// CHECK:           linalg.yield  %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
+// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
+// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
+// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
+// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
+// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
+// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       return %[[INS4]] : tensor<?xf32>

>From cec91978c105a95edb1ba42d784685a652f98da0 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 02/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

1. The transformation has been moved to
   mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

2. Test in mlir/test/Dialect/SCF/loop-continuous-peel.mlir
   simplified using scf.for

3. Added -scf-for-loop-continuous-peeling pass for applying
   this transform independently on scf.for loops.

This commit should be squashed into the original.
---
 .../SCF/TransformOps/SCFTransformOps.td       |   5 +-
 .../mlir/Dialect/SCF/Transforms/Passes.h      |   6 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  11 +
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  41 ++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 145 ++----------
 .../SCF/Transforms/LoopSpecialization.cpp     | 221 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     | 134 ++++-------
 7 files changed, 336 insertions(+), 227 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 4ef2e1b30ec202..b39cd1490d86ad 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -176,9 +176,10 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
 
   let arguments =
       (ins TransformHandleTypeInterface:$target,
-           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+           DefaultValuedAttr<BoolAttr, "false">:$convert_single_iter_loops_to_if);
   // TODO: Return both the peeled loop and the remainder loop.
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs TransformHandleTypeInterface:$peeled_loop,
+                      TransformHandleTypeInterface:$remainder_loop);
 
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..9ff1d6a07f17c3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -31,6 +31,12 @@ std::unique_ptr<Pass> createForLoopSpecializationPass();
 /// better vectorization.
 std::unique_ptr<Pass> createForLoopPeelingPass();
 
+/// Creates a pass that transforms a for loop into a chain of loops
+/// where the step size is always a power of 2 but decreases exponentially
+/// across the loops. Helps with dividing the iteration space across all
+/// resulting peeled loops evenly.
+std::unique_ptr<Pass> createForLoopContinuousPeelingPass();
+
 /// Creates a pass that canonicalizes affine.min and affine.max operations
 /// inside of scf.for loops with known lower and upper bounds.
 std::unique_ptr<Pass> createSCFForLoopCanonicalizationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..6f8462329c9d47 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -43,6 +43,17 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
   let dependentDialects = ["affine::AffineDialect"];
 }
 
+def SCFForLoopContinuousPeeling : Pass<"scf-for-loop-continuous-peeling"> {
+  let summary = "Convert a loop into a chain of loops with exponentially decreasing steps that are power of 2.";
+  let constructor = "mlir::createForLoopContinuousPeelingPass()";
+  let options = [
+    Option<"convertSingleIterLoopsToIf", "convert-single-iter-loops-to-if", "bool",
+           /*default=*/"false",
+           "Convert single iteration loops to if. ">
+  ];
+  let dependentDialects = ["affine::AffineDialect"];
+}
+
 def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";
   let constructor = "mlir::createForLoopSpecializationPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index e91f9e4469ab72..d2284aeacb175d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -89,6 +89,47 @@ LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
 LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp,
                                         scf::ForOp &partialIteration);
 
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops.
+///
+/// E.g., assuming a lower bound of 0, the following loop
+/// ```
+/// scf.for %iv = %c0 to %ub step %c8 {
+///   (loop body)
+/// }
+/// ```
+/// is rewritten into the following pseudo IR:
+/// ```
+/// %newUb = %ub - (%ub mod %c8)
+/// scf.for %iv = %c0 to %newUb step %c8 {
+///   (loop body)
+/// }
+/// %newUb2 = %ub - (%ub mod %c4)
+/// scf.for %iv2 = %newUb to %newUb2 {
+///   (loop body)
+/// }
+/// %newUb3 = %ub - (%ub mod %c2)
+/// scf.for %iv2 = %newUb2 to %newUb3 {
+///   (loop body)
+/// }
+/// scf.for %iv2 = %newUb3 to %ub {
+///   (loop body)
+/// }
+/// ```
+///
+/// Similar to loop peeling, this function simplifies the affine.min and
+/// affine.max ops in the body of each resulting for loop for better
+/// canonicalization opportunities.
+///
+/// The return value indicates if the loop was rewritten. The loop
+/// is not rewritten if the step size is 1 or dynamic.
+
+LogicalResult
+continuousPeelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
+                                       scf::ForOp &partialIteration,
+                                       bool convertSingleIterLoopsToIf);
+
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 10806d9e6ba022..f9355fb03e1bc2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -221,146 +221,27 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
 // LoopContinuousPeelOp
 //===---------------------------------------------------------------------===//
 
-static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
-                                     scf::ForOp &partialIteration,
-                                     Value &splitBound) {
-  RewriterBase::InsertionGuard guard(b);
-  auto lbInt = getConstantIntValue(forOp.getLowerBound());
-  auto ubInt = getConstantIntValue(forOp.getUpperBound());
-  auto stepInt = getConstantIntValue(forOp.getStep());
-
-  // No specialization necessary if step already divides upper bound evenly.
-  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
-    return failure();
-  // No specialization necessary if step size is 1.
-  if (stepInt == static_cast<int64_t>(1))
-    return failure();
-
-  // Create ForOp for partial iteration.
-  b.setInsertionPointAfter(forOp);
-  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
-  partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
-  partialIteration.getInitArgsMutable().assign(forOp->getResults());
-
-  // Set new upper loop bound.
-  b.updateRootInPlace(
-      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
-
-  return success();
-}
-
-static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
-  Location loc = forOp->getLoc();
-  IRMapping mapping;
-  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
-  for (auto [arg, operand] :
-       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
-    mapping.map(arg, operand.get());
-  }
-  b.setInsertionPoint(forOp);
-  auto cond =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                              forOp.getLowerBound(), forOp.getUpperBound());
-  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
-  // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
-  // else branch
-  b.setInsertionPointToStart(ifOp.elseBlock());
-  if (!forOp->getResultTypes().empty()) {
-    b.create<scf::YieldOp>(loc, forOp.getInits());
-  }
-  b.replaceOp(forOp, ifOp->getResults());
-  return ifOp;
-}
-
 DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  scf::ForOp loop, currentLoop, partialLoop;
+  scf::ForOp loop, result;
   loop = dyn_cast<scf::ForOp>(target);
-  auto lbInt = getConstantIntValue(loop.getLowerBound());
-  auto stepInt = getConstantIntValue(loop.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return DiagnosedSilenceableFailure::
-        definiteFailure(); // step size must be a known positive constant
-  Value initialUb = loop.getUpperBound();
-  Value initialStep = loop.getStep();
-  uint64_t loopStep = *stepInt;
-  currentLoop = loop;
-  AffineExpr sym0, sym1, sym2;
-  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
-  AffineMap defaultSplitMap =
-      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
-  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
-  bool usePowerSplit = (lbInt.has_value()) &&
-                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
-                       (loopStep == llvm::bit_floor(loopStep));
-  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
-  SmallVector<scf::ForOp> loops;
-  while (loopStep) {
-    rewriter.setInsertionPoint(currentLoop);
-    auto constStepOp =
-        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
-    currentLoop.getStepMutable().assign(constStepOp);
-    rewriter.setInsertionPoint(currentLoop);
-    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
-        currentLoop.getLoc(), splitMap,
-        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
-                   currentLoop.getStep()});
-    LogicalResult status =
-        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
-
-    // Canonicalize min/max affine operations
-    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
-    // they are then replaced by the current step size.
-    // TODO: Alternative method - update affine map to reflect the loop step
-    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });
+  bool convertSingleIterLoopsToIf = false;
 
-    // Prepare for the next iteration
-    loops.push_back(currentLoop);
-    if (failed(status))
-      break;
-    currentLoop = partialLoop;
-    uint64_t maxPower = llvm::bit_floor(loopStep);
-    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
-  }
-  assert(loops.size() > 0 && "There should be at least one loop available");
-  if (getSingleIterOpt()) {
-    for (size_t i = 1; i < loops.size(); ++i) {
-      convertSingleIterFor(rewriter, loops[i]);
-    }
+  if (getConvertSingleIterLoopsToIf())
+    convertSingleIterLoopsToIf = true;
+
+  LogicalResult status = scf::continuousPeelForLoopAndSimplifyBounds(
+      rewriter, loop, result, convertSingleIterLoopsToIf);
+  if (failed(status)) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to perform continuous peeling";
+    return diag;
   }
 
-  results.push_back(loops.front());
+  results.push_back(loop);
+  results.push_back(result);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 9fda4861d40a3b..81d9c79de84a40 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -28,6 +28,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFFORLOOPPEELING
+#define GEN_PASS_DEF_SCFFORLOOPCONTINUOUSPEELING
 #define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
 #define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -105,6 +106,165 @@ static void specializeForLoopForUnrolling(ForOp op) {
   op.erase();
 }
 
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value &splitBound) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+
+  // No specialization necessary if step already divides upper bound evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No specialization necessary if step size is 1.
+  if (stepInt == static_cast<int64_t>(1))
+    return failure();
+
+  // Create ForOp for partial iteration.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  forOp.replaceAllUsesWith(partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+
+  // Set new upper loop bound.
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+
+  return success();
+}
+
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
+    mapping.map(arg, operand.get());
+  }
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // then branch
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
+  // else branch
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty()) {
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  }
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+}
+
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops. Helps
+/// divide the iteration space across all resulting peeled loops evenly.
+///
+/// Optionally, convert all single iteration for loops to if-else
+/// blocks when convert_single_iter_loops_to_if attribute is set to true or
+/// alternatively with the convert-single-iter-loops-to-if option for the
+/// scf-for-loop-continuous-peeling pass.
+static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
+                                           ForOp &partialIteration,
+                                           bool convertSingleIterLoopsToIf) {
+
+  scf::ForOp currentLoop;
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return failure(); // step size must be a known positive constant
+  Value initialUb = forOp.getUpperBound();
+  Value initialStep = forOp.getStep();
+  uint64_t loopStep = *stepInt;
+  currentLoop = forOp;
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(b.getContext(), sym0, sym1, sym2);
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    b.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        b.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    currentLoop.getStepMutable().assign(constStepOp);
+    b.setInsertionPoint(currentLoop);
+    Value splitBound = b.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(b, currentLoop, partialIteration, splitBound);
+
+    // Canonicalize min/max affine operations
+    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
+    // they are then replaced by the current step size.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+
+    // Prepare for the next iteration
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialIteration;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+
+  assert(loops.size() > 0 && "There should be at least one loop available");
+  if (convertSingleIterLoopsToIf) {
+    for (size_t i = 1; i < loops.size(); ++i) {
+      convertSingleIterFor(b, loops[i]);
+    }
+  }
+
+  return success();
+}
+
+LogicalResult mlir::scf::continuousPeelForLoopAndSimplifyBounds(
+    RewriterBase &rewriter, ForOp forOp, ForOp &partialIteration,
+    bool convertSingleIterLoopsToIf) {
+
+  if (failed(continuousPeelForLoop(rewriter, forOp, partialIteration,
+                                   convertSingleIterLoopsToIf)))
+    return failure();
+
+  return success();
+}
+
 /// Rewrite a for loop with bounds/step that potentially do not divide evenly
 /// into a for loop where the step divides the iteration space evenly, followed
 /// by an scf.if for the last (partial) iteration (if any).
@@ -308,6 +468,45 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
 };
 } // namespace
 
+namespace {
+struct ForLoopContinuousPeelingPattern : public OpRewritePattern<ForOp> {
+  ForLoopContinuousPeelingPattern(MLIRContext *ctx,
+                                  bool convertSingleIterLoopsToIf)
+      : OpRewritePattern<ForOp>(ctx),
+        convertSingleIterLoopsToIf(convertSingleIterLoopsToIf) {}
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Do not peel already peeled loops.
+    if (forOp->hasAttr(kPeeledLoopLabel))
+      return failure();
+
+    // Apply continuous loop peeling.
+    scf::ForOp partialIteration;
+    if (failed(continuousPeelForLoopAndSimplifyBounds(
+            rewriter, forOp, partialIteration, convertSingleIterLoopsToIf)))
+      return failure();
+
+    rewriter.updateRootInPlace(partialIteration, [&]() {
+      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+    });
+    rewriter.updateRootInPlace(forOp, [&]() {
+      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    });
+
+    return success();
+  }
+
+  /// If set to true, every loop that follows the first loop in the chain
+  /// produced by continuous peeling is rewritten into an scf.if operation
+  /// (see convertSingleIterFor).
+  /// Note: The flag is forwarded unchanged to
+  /// continuousPeelForLoopAndSimplifyBounds.
+  bool convertSingleIterLoopsToIf;
+};
+} // namespace
+
 namespace {
 struct ParallelLoopSpecialization
     : public impl::SCFParallelLoopSpecializationBase<
@@ -340,6 +539,24 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
     });
   }
 };
+
+struct ForLoopContinuousPeeling
+    : public impl::SCFForLoopContinuousPeelingBase<ForLoopContinuousPeeling> {
+  void runOnOperation() override {
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopContinuousPeelingPattern>(ctx,
+                                                  convertSingleIterLoopsToIf);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
+
+    // Drop the markers.
+    parentOp->walk([](Operation *op) {
+      op->removeAttr(kPeeledLoopLabel);
+      op->removeAttr(kPartialIterationLabel);
+    });
+  }
+};
 } // namespace
 
 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
@@ -353,3 +570,7 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
 std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
   return std::make_unique<ForLoopPeeling>();
 }
+
+std::unique_ptr<Pass> mlir::createForLoopContinuousPeelingPass() {
+  return std::make_unique<ForLoopContinuousPeeling>();
+}
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
index 752e1b1efed92a..e051e6a43be70e 100644
--- a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -1,98 +1,46 @@
-// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -scf-for-loop-continuous-peeling=convert-single-iter-loops-to-if=true -split-input-file | FileCheck %s
 
-#map = affine_map<(d0) -> ()>
-#map1 = affine_map<(d0) -> (d0)>
-module {
-  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %3 = arith.mulf %in, %out : f32
-      linalg.yield %3 : f32
-    } -> tensor<?xf32>
-    return %0 : tensor<?xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
-    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
-    transform.yield
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @foo(%ub: index) -> index {
+  %c0 = arith.constant 0 : index
+  %step = arith.constant 8 : index
+  %0 = scf.for %iv = %c0 to %ub step %step iter_args(%arg = %c0) -> (index) {
+    %1 = affine.min #map(%ub, %iv)[%step]
+    %2 = index.add %1, %arg
+    scf.yield %2 : index
   }
+  return %0 : index
 }
 
 // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
-// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
-// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
-// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %{{.*}} = arith.constant 8 : index
-// CHECK:       %[[C8:.*]] = arith.constant 8 : index
-// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
-// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
-// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C4:.*]] = arith.constant 4 : index
-// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
-// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
-// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
-// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
-// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
-// CHECK:           linalg.yield  %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
-// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
-// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
-// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
-// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
-// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
-// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       return %[[INS4]] : tensor<?xf32>
+// CHECK: func.func @foo(%[[UB:.*]]: index) -> index {
+// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
+// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
+// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
+// CHECK: %[[STEP1:.*]] = arith.constant 1 : index
+// CHECK: %[[LB:.*]] = arith.constant 0 : index
+// CHECK: %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
+// CHECK: %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[ALB]], %[[STEP8]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
+// CHECK: %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
+// CHECK: %[[I4:.*]] = scf.if %[[I3]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I1]], %[[STEP4]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I1]] : index
+// CHECK: %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
+// CHECK: %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
+// CHECK: %[[I7:.*]] = scf.if %[[I6]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I4]], %[[STEP2]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I4]] : index
+// CHECK: %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
+// CHECK: %[[I9:.*]] = scf.if %[[I8]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I7]], %[[STEP1]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I7]] : index
+// CHECK: return %[[I9]] : index

>From 942771c5b752843cd8057d3249403fa922cd8fe0 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 03/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Add case for step size 1: the transform now also fails when the constant step is 1.

This commit should be squashed into the original.
---
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 81d9c79de84a40..17b691b1282108 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -178,8 +178,11 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   scf::ForOp currentLoop;
   auto lbInt = getConstantIntValue(forOp.getLowerBound());
   auto stepInt = getConstantIntValue(forOp.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return failure(); // step size must be a known positive constant
+
+  // Step size must be a known positive constant greater than 1.
+  if (stepInt && stepInt <= static_cast<int64_t>(1))
+    return failure();
+
   Value initialUb = forOp.getUpperBound();
   Value initialStep = forOp.getStep();
   uint64_t loopStep = *stepInt;

>From 4f197b94882e3d7b5e05b7172b7ae2f5e2bcc8c1 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 10 Nov 2023 20:42:20 +0800
Subject: [PATCH 04/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Removed redundant TODO comment.

This commit should be squashed into the original.
---
 mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index b39cd1490d86ad..72b8b359d8ec98 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -177,7 +177,7 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
   let arguments =
       (ins TransformHandleTypeInterface:$target,
            DefaultValuedAttr<BoolAttr, "false">:$convert_single_iter_loops_to_if);
-  // TODO: Return both the peeled loop and the remainder loop.
+
   let results = (outs TransformHandleTypeInterface:$peeled_loop,
                       TransformHandleTypeInterface:$remainder_loop);
 

>From 1e51c4799376e1c7e8e0ed4a7545c852ba11d5c8 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 10 Nov 2023 23:38:29 +0800
Subject: [PATCH 05/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Renamed test to conform to other similar tests.

This commit should be squashed into the original.
---
 ...op-continuous-peel.mlir => for-loop-continuous-peeling.mlir} | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
 rename mlir/test/Dialect/SCF/{loop-continuous-peel.mlir => for-loop-continuous-peeling.mlir} (98%)

diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
similarity index 98%
rename from mlir/test/Dialect/SCF/loop-continuous-peel.mlir
rename to mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
index e051e6a43be70e..37a16c6dd094b7 100644
--- a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
@@ -43,4 +43,4 @@ func.func @foo(%ub: index) -> index {
 // CHECK: scf.yield %[[SUM]] : index
 // CHECK: } else {
 // CHECK: scf.yield %[[I7]] : index
-// CHECK: return %[[I9]] : index
+// CHECK: return %[[I9]] : index
\ No newline at end of file

>From 162653ec5a529bfdd3e51f5ce6e17c3bfe3de9eb Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 14 Nov 2023 23:44:35 +0800
Subject: [PATCH 06/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

1. Added comment for splitLoopHelper
2. Added comment for convertSingleIterFor.
3. Some minor changes as per review suggestions.

This commit should be squashed into the original.
---
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 17b691b1282108..e33e8331ef4c95 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -106,7 +106,11 @@ static void specializeForLoopForUnrolling(ForOp op) {
   op.erase();
 }
 
-static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+/// Create a new for loop for the remaining iterations (partialIteration)
+/// after a for loop has been peeled. This is followed by correcting the
+/// loop bounds for both loops given the index (splitBound) where the
+/// iteration space is to be split up.
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp forOp,
                                      scf::ForOp &partialIteration,
                                      Value &splitBound) {
   RewriterBase::InsertionGuard guard(b);
@@ -135,6 +139,7 @@ static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
   return success();
 }
 
+/// Convert single-iteration for loop to if-else block.
 static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
   Location loc = forOp->getLoc();
   IRMapping mapping;
@@ -180,7 +185,7 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   auto stepInt = getConstantIntValue(forOp.getStep());
 
   // Step size must be a known positive constant greater than 1.
-  if (stepInt && stepInt <= static_cast<int64_t>(1))
+  if (!stepInt || stepInt <= static_cast<int64_t>(1))
     return failure();
 
   Value initialUb = forOp.getUpperBound();
@@ -249,6 +254,7 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
 
   assert(loops.size() > 0 && "There should be at least one loop available");
   if (convertSingleIterLoopsToIf) {
+    // Loops following the first loop will always be single-iteration.
     for (size_t i = 1; i < loops.size(); ++i) {
       convertSingleIterFor(b, loops[i]);
     }

>From 42bad66fe0981f78e1e833b261eee3d56d0f2fae Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 12 Dec 2023 23:38:03 +0800
Subject: [PATCH 07/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Use the rewriter's block-inlining function (inlineBlockBefore) instead of
manually cloning operations from the replaced for loop.

This commit should be squashed into the original.
---
 .../SCF/Transforms/LoopSpecialization.cpp     | 30 ++++++++-----------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index e33e8331ef4c95..3e9ea6ad9e19af 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -154,10 +154,15 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
                               forOp.getLowerBound(), forOp.getUpperBound());
   auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
   // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
+  ifOp->moveBefore(forOp);
+  // Replace block arguments with lower bound (replacement for IV) and
+  // iter_args.
+  SmallVector<Value> bbArgReplacements;
+  bbArgReplacements.push_back(forOp.getLowerBound());
+  llvm::append_range(bbArgReplacements, forOp.getInitArgs());
+
+  b.inlineBlockBefore(forOp.getBody(), ifOp.thenBlock(),
+                      ifOp.thenBlock()->begin(), bbArgReplacements);
   // else branch
   b.setInsertionPointToStart(ifOp.elseBlock());
   if (!forOp->getResultTypes().empty()) {
@@ -220,20 +225,11 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
+    currentLoop.walk([&](Operation *affineOp) {
+      if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+        return WalkResult::advance();
       b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
+      auto clonedOp = b.clone(*affineOp);
       LogicalResult result = scf::rewritePeeledMinMaxOp(
           b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
           /*insideLoop=*/true);

>From e4ffd004ac533c8c3816139205ce2a86fa7e9d09 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 14 Dec 2023 17:40:39 +0800
Subject: [PATCH 08/13] Revert "[MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF"

This reverts commit 1cfe41741ba0eb92e24861ba8bfd239059554540.
---
 .../SCF/Transforms/LoopSpecialization.cpp     | 30 +++++++++++--------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 3e9ea6ad9e19af..e33e8331ef4c95 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -154,15 +154,10 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
                               forOp.getLowerBound(), forOp.getUpperBound());
   auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
   // then branch
-  ifOp->moveBefore(forOp);
-  // Replace block arguments with lower bound (replacement for IV) and
-  // iter_args.
-  SmallVector<Value> bbArgReplacements;
-  bbArgReplacements.push_back(forOp.getLowerBound());
-  llvm::append_range(bbArgReplacements, forOp.getInitArgs());
-
-  b.inlineBlockBefore(forOp.getBody(), ifOp.thenBlock(),
-                      ifOp.thenBlock()->begin(), bbArgReplacements);
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    b.clone(op, mapping);
+  }
   // else branch
   b.setInsertionPointToStart(ifOp.elseBlock());
   if (!forOp->getResultTypes().empty()) {
@@ -225,11 +220,20 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](Operation *affineOp) {
-      if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
-        return WalkResult::advance();
+    currentLoop.walk([&](affine::AffineMinOp affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // to avoid infinite walk
+    });
+    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
       b.setInsertionPoint(affineOp);
-      auto clonedOp = b.clone(*affineOp);
+      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
       LogicalResult result = scf::rewritePeeledMinMaxOp(
           b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
           /*insideLoop=*/true);

>From e4a1cc79b11c6345b7fd14ef546f28810abe4092 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 14 Dec 2023 19:02:26 +0800
Subject: [PATCH 09/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Clean up the rewriting of the affine op in the partial iterations.
This required changing the return type of rewritePeeledMinMaxOp
from LogicalResult to FailureOr<AffineApplyOp> to pass on the
newly created affine op.

This commit should be squashed into the original.
---
 .../SCF/Utils/AffineCanonicalizationUtils.h   |  7 +++--
 .../SCF/Transforms/LoopSpecialization.cpp     | 31 ++++++-------------
 .../SCF/Utils/AffineCanonicalizationUtils.cpp | 10 +++---
 3 files changed, 19 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
index 5022cbf898ec17..8ef1d0a5ffe803 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_AFFINECANONICALIZATIONUTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_AFFINECANONICALIZATIONUTILS_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -75,9 +76,9 @@ LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
 /// ```
 /// min/max operations inside the partial iteration are rewritten in a similar
 /// way.
-LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                    Value iv, Value ub, Value step,
-                                    bool insideLoop);
+FailureOr<mlir::affine::AffineApplyOp>
+rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, Value iv, Value ub,
+                      Value step, bool insideLoop);
 
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index e33e8331ef4c95..297aeebcd5adcc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -220,27 +220,16 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // they are then replaced by the current step size.
     // TODO: Alternative method - update affine map to reflect the loop step
     // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      b.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(b.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        b.replaceOp(affineOp, currentLoop.getStep());
-      else
-        b.eraseOp(clonedOp); // to avoid infinite walk
+    currentLoop.walk([&](Operation *affineOp) {
+      if (isa<AffineMinOp, AffineMaxOp>(affineOp)) {
+        FailureOr<AffineApplyOp> result = scf::rewritePeeledMinMaxOp(
+            b, affineOp, currentLoop.getInductionVar(), initialUb, initialStep,
+            /*insideLoop=*/true);
+        // correct the step of the newly created affine op
+        if (!failed(result))
+          b.replaceOp(result.value(), currentLoop.getStep());
+      }
+      return WalkResult::advance();
     });
 
     // Prepare for the next iteration
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 89dfb61d48af94..7c748ab5c9b8f1 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -14,7 +14,6 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
@@ -189,16 +188,17 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
 /// * Inside the peeled loop: min(step, ub - iv) == step
 /// * Inside the partial iteration: min(step, ub - iv) == ub - iv
 ///
-/// Returns `success` if the given operation was replaced by a new operation;
+/// Returns the new Affine op if the operation was replaced by a new operation;
 /// `failure` otherwise.
 ///
 /// Note: `ub` is the previous upper bound of the loop (before peeling).
 /// `insideLoop` must be true for min/max ops inside the loop and false for
 /// affine.min ops inside the partial iteration. For an explanation of the other
 /// parameters, see comment of `canonicalizeMinMaxOpInLoop`.
-LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                         Value iv, Value ub, Value step,
-                                         bool insideLoop) {
+FailureOr<AffineApplyOp> scf::rewritePeeledMinMaxOp(RewriterBase &rewriter,
+                                                    Operation *op, Value iv,
+                                                    Value ub, Value step,
+                                                    bool insideLoop) {
   FlatAffineValueConstraints constraints;
   constraints.appendDimVar({iv});
   constraints.appendSymbolVar({ub, step});

>From e9e40ffdb76ffacb18e2611f2d5c24993afb89cd Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 14 Dec 2023 19:26:51 +0800
Subject: [PATCH 10/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Use block inlining to move the operations inside a for op into the
newly created target if op when converting single-iteration for ops
to if ops.

This commit should be squashed into the original.
---
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 297aeebcd5adcc..90748efd40dd36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -154,10 +154,12 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
                               forOp.getLowerBound(), forOp.getUpperBound());
   auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
   // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
+  SmallVector<Value> bbArgReplacements;
+  bbArgReplacements.push_back(forOp.getLowerBound());
+  llvm::append_range(bbArgReplacements, forOp.getInitArgs());
+
+  b.inlineBlockBefore(forOp.getBody(), ifOp.thenBlock(),
+                      ifOp.thenBlock()->begin(), bbArgReplacements);
   // else branch
   b.setInsertionPointToStart(ifOp.elseBlock());
   if (!forOp->getResultTypes().empty()) {

>From a850f6470a0710b425cb971b526afd0a3925f1fe Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Mon, 8 Jan 2024 20:02:40 +0800
Subject: [PATCH 11/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Some minor fixes as per reviewer comments.

This commit should be squashed into the original.
---
 .../mlir/Dialect/SCF/TransformOps/SCFTransformOps.td        | 2 +-
 mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h       | 6 +++---
 mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp       | 6 +-----
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp      | 6 ++++--
 4 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 72b8b359d8ec98..0999ef5ca75804 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -161,7 +161,7 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
      TransformOpInterface, TransformEachOpTrait]> {
   let description = [{
     Transforms the loop into a chain of loops, with step sizes that are
-    powers of two and decrease exponetially across subsequent loops.
+    powers of two and decrease exponentially across subsequent loops.
     The transform is similar to loop.peel in the effect that it creates a loop
     with a step (that is power of 2) to divide the range evenly, with the
     difference that the remaining iterations are spread across similar loops
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index d2284aeacb175d..35d0fd87dc27a7 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -106,14 +106,14 @@ LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp,
 ///   (loop body)
 /// }
 /// %newUb2 = %ub - (%ub mod %c4)
-/// scf.for %iv2 = %newUb to %newUb2 {
+/// scf.for %iv2 = %newUb to %newUb2 step %c4 {
 ///   (loop body)
 /// }
 /// %newUb3 = %ub - (%ub mod %c2)
-/// scf.for %iv2 = %newUb2 to %newUb3 {
+/// scf.for %iv2 = %newUb2 to %newUb3 step %c2 {
 ///   (loop body)
 /// }
-/// scf.for %iv2 = %newUb3 to %ub {
+/// scf.for %iv2 = %newUb3 to %ub step %c1 {
 ///   (loop body)
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index f9355fb03e1bc2..c1f752a3daf500 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -227,13 +227,9 @@ DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
     transform::TransformState &state) {
   scf::ForOp loop, result;
   loop = dyn_cast<scf::ForOp>(target);
-  bool convertSingleIterLoopsToIf = false;
-
-  if (getConvertSingleIterLoopsToIf())
-    convertSingleIterLoopsToIf = true;
 
   LogicalResult status = scf::continuousPeelForLoopAndSimplifyBounds(
-      rewriter, loop, result, convertSingleIterLoopsToIf);
+      rewriter, loop, result, getConvertSingleIterLoopsToIf());
   if (failed(status)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError() << "failed to perform continuous peeling";
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 90748efd40dd36..50bbe06f68bff7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -145,7 +145,7 @@ static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
   IRMapping mapping;
   mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
   for (auto [arg, operand] :
-       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
+       llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
     mapping.map(arg, operand.get());
   }
   b.setInsertionPoint(forOp);
@@ -208,7 +208,9 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     b.setInsertionPoint(currentLoop);
     auto constStepOp =
         b.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
-    currentLoop.getStepMutable().assign(constStepOp);
+    b.updateRootInPlace(currentLoop, [&]() {
+      currentLoop.getStepMutable().assign(constStepOp);
+    });
     b.setInsertionPoint(currentLoop);
     Value splitBound = b.createOrFold<affine::AffineApplyOp>(
         currentLoop.getLoc(), splitMap,

>From e2f8f42355218863595620551da4e3b6af422369 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Mon, 8 Jan 2024 22:55:59 +0800
Subject: [PATCH 12/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Minor reordering of the CHECK lines for the step constants in the test.

This commit should be squashed into the original.
---
 mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
index 37a16c6dd094b7..a78b3c914a0146 100644
--- a/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
@@ -14,11 +14,11 @@ func.func @foo(%ub: index) -> index {
 
 // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
 // CHECK: func.func @foo(%[[UB:.*]]: index) -> index {
-// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
-// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
-// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
 // CHECK: %[[STEP1:.*]] = arith.constant 1 : index
+// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
+// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
 // CHECK: %[[LB:.*]] = arith.constant 0 : index
+// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
 // CHECK: %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
 // CHECK: %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
 // CHECK: %[[SUM:.*]] = index.add %[[ALB]], %[[STEP8]]

>From d10f11f08c230eb7de4cb11bb67b0e763954fb28 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 9 Jan 2024 18:46:17 +0800
Subject: [PATCH 13/13] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling
 transform to SCF

Changes as per reviewer comments. Adds a test for the case where the
lower loop bound is not divisible by the step.

This commit should be squashed into the original.
---
 .../SCF/Transforms/LoopSpecialization.cpp     |  28 ++---
 .../SCF/Utils/AffineCanonicalizationUtils.cpp |   4 +-
 .../SCF/for-loop-continuous-peeling.mlir      | 100 +++++++++++++-----
 3 files changed, 89 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 50bbe06f68bff7..b9592588896616 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -106,13 +106,14 @@ static void specializeForLoopForUnrolling(ForOp op) {
   op.erase();
 }
 
-/// Create a new for loop for the remaining iterations (partiaIteration)
+/// Create a new for loop for the remaining iterations (partialIteration)
 /// after a for loop has been peeled. This is followed by correcting the
 /// loop bounds for both loops given the index (splitBound) where the
-/// iteration space is to be split up.
+/// iteration space is to be split up. Returns failure if the loop can not
+/// be split and no new partialIteration is created.
 static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp forOp,
                                      scf::ForOp &partialIteration,
-                                     Value &splitBound) {
+                                     Value splitBound) {
   RewriterBase::InsertionGuard guard(b);
   auto lbInt = getConstantIntValue(forOp.getLowerBound());
   auto ubInt = getConstantIntValue(forOp.getUpperBound());
@@ -142,17 +143,13 @@ static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp forOp,
 /// Convert single-iteration for loop to if-else block.
 static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
   Location loc = forOp->getLoc();
-  IRMapping mapping;
-  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
-  for (auto [arg, operand] :
-       llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
-    mapping.map(arg, operand.get());
-  }
+
   b.setInsertionPoint(forOp);
   auto cond =
       b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                               forOp.getLowerBound(), forOp.getUpperBound());
-  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond,
+                                  /*withElseRegion=*/true);
   // then branch
   SmallVector<Value> bbArgReplacements;
   bbArgReplacements.push_back(forOp.getLowerBound());
@@ -198,6 +195,9 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   bindSymbols(b.getContext(), sym0, sym1, sym2);
   AffineMap defaultSplitMap =
       AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  // Simplify defaultSplitMap if lower-bound is exactly divisible by step.
+  // In that case sym0 % sym2 (lb % step) is ignored as it equals 0 and
+  // we get a simpler map expressed as powerSplitMap here.
   AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
   bool usePowerSplit = (lbInt.has_value()) &&
                        (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
@@ -222,15 +222,15 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
     // Canonicalize min/max affine operations
     // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
     // they are then replaced by the current step size.
-    // TODO: Alternative method - update affine map to reflect the loop step
-    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    // TODO: Alternative method - update affine map to reflect the corrected
+    // loop step - Example: min(ub - iv, 8) -> min(ub - iv, 4)
     currentLoop.walk([&](Operation *affineOp) {
       if (isa<AffineMinOp, AffineMaxOp>(affineOp)) {
         FailureOr<AffineApplyOp> result = scf::rewritePeeledMinMaxOp(
             b, affineOp, currentLoop.getInductionVar(), initialUb, initialStep,
             /*insideLoop=*/true);
         // correct the step of the newly created affine op
-        if (!failed(result))
+        if (succeeded(result))
           b.replaceOp(result.value(), currentLoop.getStep());
       }
       return WalkResult::advance();
@@ -242,7 +242,7 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
       break;
     currentLoop = partialIteration;
     uint64_t maxPower = llvm::bit_floor(loopStep);
-    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+    loopStep = maxPower == loopStep ? maxPower / 2 : maxPower;
   }
 
   assert(loops.size() > 0 && "There should be at least one loop available");
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 7c748ab5c9b8f1..83db8a552c82a2 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -188,8 +188,8 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
 /// * Inside the peeled loop: min(step, ub - iv) == step
 /// * Inside the partial iteration: min(step, ub - iv) == ub - iv
 ///
-/// Returns the new Affine op if the operation was replaced by a new operation;
-/// `failure` otherwise.
+/// Returns the new AffineApplyOp if the operation was replaced by a new
+/// operation; `failure` otherwise.
 ///
 /// Note: `ub` is the previous upper bound of the loop (before peeling).
 /// `insideLoop` must be true for min/max ops inside the loop and false for
diff --git a/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
index a78b3c914a0146..f611f4659633ab 100644
--- a/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-continuous-peeling.mlir
@@ -12,35 +12,81 @@ func.func @foo(%ub: index) -> index {
   return %0 : index
 }
 
-// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
+// CHECK:   #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
 // CHECK: func.func @foo(%[[UB:.*]]: index) -> index {
 // CHECK: %[[STEP1:.*]] = arith.constant 1 : index
 // CHECK: %[[STEP2:.*]] = arith.constant 2 : index
 // CHECK: %[[STEP4:.*]] = arith.constant 4 : index
-// CHECK: %[[LB:.*]] = arith.constant 0 : index
+// CHECK:    %[[LB:.*]] = arith.constant 0 : index
 // CHECK: %[[STEP8:.*]] = arith.constant 8 : index
-// CHECK: %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
-// CHECK: %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
-// CHECK: %[[SUM:.*]] = index.add %[[ALB]], %[[STEP8]]
-// CHECK: scf.yield %[[SUM]] : index
-// CHECK: %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
-// CHECK: %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
-// CHECK: %[[I4:.*]] = scf.if %[[I3]] -> (index) {
-// CHECK: %[[SUM:.*]] = index.add %[[I1]], %[[STEP4]]
-// CHECK: scf.yield %[[SUM]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[I1]] : index
-// CHECK: %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
-// CHECK: %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
-// CHECK: %[[I7:.*]] = scf.if %[[I6]] -> (index) {
-// CHECK: %[[SUM:.*]] = index.add %[[I4]], %[[STEP2]]
-// CHECK: scf.yield %[[SUM]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[I4]] : index
-// CHECK: %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
-// CHECK: %[[I9:.*]] = scf.if %[[I8]] -> (index) {
-// CHECK: %[[SUM:.*]] = index.add %[[I7]], %[[STEP1]]
-// CHECK: scf.yield %[[SUM]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[I7]] : index
-// CHECK: return %[[I9]] : index
\ No newline at end of file
+// CHECK:    %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
+// CHECK:    %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[ALB]], %[[STEP8]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:    %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
+// CHECK:    %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
+// CHECK:    %[[I4:.*]] = scf.if %[[I3]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I1]], %[[STEP4]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I1]] : index
+// CHECK:    %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
+// CHECK:    %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
+// CHECK:    %[[I7:.*]] = scf.if %[[I6]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I4]], %[[STEP2]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I4]] : index
+// CHECK:    %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
+// CHECK:    %[[I9:.*]] = scf.if %[[I8]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I7]], %[[STEP1]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I7]] : index
+// CHECK:                 return %[[I9]] : index
+
+// -----
+
+#map1 = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @foo1(%ub: index) -> index {
+  %c0 = arith.constant 1 : index
+  %step = arith.constant 8 : index
+  %0 = scf.for %iv = %c0 to %ub step %step iter_args(%arg = %c0) -> (index) {
+    %1 = affine.min #map1(%ub, %iv)[%step]
+    %2 = index.add %1, %arg
+    scf.yield %2 : index
+  }
+  return %0 : index
+}
+
+// CHECK:   #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>
+// CHECK: func.func @foo1(%[[UB:.*]]: index) -> index {
+// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
+// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
+// CHECK: %[[STEP1:.*]] = arith.constant 1 : index
+// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
+// CHECK:    %[[I0:.*]] = affine.apply #[[MAP]]()[%[[STEP1]], %[[UB]], %[[STEP8]]]
+// CHECK:    %[[I1:.*]] = scf.for %{{.*}} = %[[STEP1]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[STEP1]]) -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[ALB]], %[[STEP8]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:    %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
+// CHECK:    %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
+// CHECK:    %[[I4:.*]] = scf.if %[[I3]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I1]], %[[STEP4]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I1]] : index
+// CHECK:    %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
+// CHECK:    %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
+// CHECK:    %[[I7:.*]] = scf.if %[[I6]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I4]], %[[STEP2]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I4]] : index
+// CHECK:    %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
+// CHECK:    %[[I9:.*]] = scf.if %[[I8]] -> (index) {
+// CHECK:   %[[SUM:.*]] =   index.add %[[I7]], %[[STEP1]]
+// CHECK:                   scf.yield %[[SUM]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[I7]] : index
+// CHECK:                 return %[[I9]] : index
\ No newline at end of file



More information about the llvm-commits mailing list