[Mlir-commits] [mlir] [MLIR][LLVM] Add Continuous Loop Peeling transform to SCF (PR #71555)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 03:16:46 PST 2023


https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/71555

>From 7bb2f9793b2a2cccbaa401f6e2ac850b587f2b59 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 1/3] [MLIR][LLVM] Add Continuous Loop Peeling transform to SCF

This patch adds continuous loop peeling to the SCF loop transforms
in MLIR. It transforms the target loop into a chain of loops
whose step sizes are powers of two and decrease exponentially
across subsequent loops. Originally authored by Litu Zhou
litu.zhou at huawei.com.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  36 +++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 147 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     |  98 ++++++++++++
 3 files changed, 281 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/loop-continuous-peel.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 14df7e23a430fb1..e3d79a7f0ae40f3 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -147,6 +147,42 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
   }];
 }
 
+def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Transforms the loop into a chain of loops. The first loop keeps the
+    original step; each subsequent loop uses the largest power of two that is
+    strictly smaller than the previous step, down to a step of 2^0 = 1 (the
+    chain stops early once no iterations are left to peel). Similar to
+    `loop.peel`, every loop in the chain only executes full steps and leaves
+    the remaining iterations to the next loop. If `single_iter_opt` is set,
+    the loops after the first (each runs at most once) become `scf.if` ops.
+
+    #### Return modes
+
+    This operation consumes the `target` handle and produces a handle to the
+    first loop of the resulting chain.
+  }];
+
+  let arguments =
+      (ins TransformHandleTypeInterface:$target,
+           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+  // TODO: Return both the peeled loop and the remainder loop.
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 62370604142cd5b..dcba6a8b406b21f 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -206,6 +206,153 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===---------------------------------------------------------------------===//
+// LoopContinuousPeelOp
+//===---------------------------------------------------------------------===//
+
+/// Splits `forOp` at `splitBound`: `forOp` keeps [lb, splitBound) and the
+/// clone `partialIteration` covers [splitBound, ub), chained via iter_args.
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value splitBound) {
+  RewriterBase::InsertionGuard guard(b);
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  // No split needed if the step already divides the iteration range evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No split needed if the step size is 1: there is no partial iteration.
+  if (stepInt && *stepInt == 1)
+    return failure();
+
+  // Create the ForOp for the partial iteration right after `forOp`. Uses are
+  // redirected through the rewriter so that its listener observes the change.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  b.replaceAllUsesWith(forOp->getResults(), partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+  // Set new upper loop bound.
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+  return success();
+}
+
+/// Replaces `forOp`, assumed to run at most once, by `scf.if (lb < ub)`.
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  RewriterBase::InsertionGuard guard(b);
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  mapping.map(forOp.getRegionIterArgs(), forOp.getInits());
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  // Without results, the builder already terminates both blocks with a yield.
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // then branch: the body, keeping its yield only when results are needed.
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->without_terminator())
+    b.clone(op, mapping);
+  if (!forOp->getResultTypes().empty())
+    b.clone(*forOp.getBody()->getTerminator(), mapping);
+  // else branch: forward the initial iter_args values unchanged.
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty())
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+}
+
+DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // The op may be applied to arbitrary payload: report non-scf.for targets as
+  // a silenceable failure instead of dereferencing a null loop below.
+  auto loop = dyn_cast<scf::ForOp>(target);
+  if (!loop)
+    return emitDefaultSilenceableFailure(target);
+  auto lbInt = getConstantIntValue(loop.getLowerBound());
+  auto stepInt = getConstantIntValue(loop.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return emitSilenceableError()
+           << "step size must be a known positive constant";
+  Value initialUb = loop.getUpperBound();
+  Value initialStep = loop.getStep();
+  uint64_t loopStep = *stepInt;
+  scf::ForOp currentLoop = loop, partialLoop;
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
+  // split = ub - ((ub - lb) mod step) works for any lb; when lb is a multiple
+  // of a power-of-2 step, the cheaper ub - (ub mod step) is equivalent.
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    rewriter.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    // Mutate through the rewriter so that the transform listener is notified.
+    rewriter.updateRootInPlace(currentLoop, [&]() {
+      currentLoop.getStepMutable().assign(constStepOp);
+    });
+    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
+
+    // Canonicalize min/max affine operations
+    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
+    // they are then replaced by the current step size.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    auto simplifyMinMax = [&](Operation *minMaxOp) {
+      rewriter.setInsertionPoint(minMaxOp);
+      // rewritePeeledMinMaxOp rewrites the op it is given, so run it on a
+      // clone and only use its success as the "replace by step" signal.
+      Operation *clonedOp = rewriter.clone(*minMaxOp);
+      if (succeeded(scf::rewritePeeledMinMaxOp(
+              rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
+              initialStep, /*insideLoop=*/true)))
+        rewriter.replaceOp(minMaxOp, currentLoop.getStep());
+      else
+        rewriter.eraseOp(clonedOp); // to avoid infinite walk
+    };
+    currentLoop.walk([&](affine::AffineMinOp op) { simplifyMinMax(op); });
+    currentLoop.walk([&](affine::AffineMaxOp op) { simplifyMinMax(op); });
+
+    // Prepare for the next iteration: the next loop uses the largest power
+    // of two strictly smaller than the current step.
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialLoop;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+  assert(!loops.empty() && "There should be at least one loop available");
+  // Each loop after the first spans fewer than two of its own steps, so it
+  // runs at most once and can optionally be turned into an scf.if.
+  if (getSingleIterOpt()) {
+    for (size_t i = 1; i < loops.size(); ++i) {
+      convertSingleIterFor(rewriter, loops[i]);
+    }
+  }
+
+  results.push_back(loops.front());
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopPipelineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
new file mode 100644
index 000000000000000..752e1b1efed92ac
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+// Payload: multiplies every element of a dynamically sized tensor by a scalar.
+#map = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+module {
+  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.mulf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<?xf32>
+    return %0 : tensor<?xf32>
+  }
+}
+// Tile by 8 and continuously peel the tile loop; the step-4/2/1 remainder loops become scf.if.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
+    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
+// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
+// CHECK:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %{{.*}} = arith.constant 8 : index
+// CHECK:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
+// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
+// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
+// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
+// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
+// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
+// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
+// CHECK:           linalg.yield  %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
+// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
+// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
+// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
+// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
+// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
+// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
+// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
+// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
+// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
+// CHECK:           linalg.yield %{{.*}} : f32
+// CHECK:         } -> tensor<?xf32>
+// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
+// CHECK:       }
+// CHECK:       return %[[INS4]] : tensor<?xf32>

>From 98ba0bd1177af08f286a0eaf16a85828f5410c7d Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 2/3] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

1. The transformation has been moved to
   mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

2. Test in mlir/test/Dialect/SCF/loop-continuous-peel.mlir
   simplified using scf.for

3. Added -scf-for-loop-continuous-peeling pass for applying
   this transform independently on scf.for loops.

This commit should be squashed into the original.
---
 .../SCF/TransformOps/SCFTransformOps.td       |   5 +-
 .../mlir/Dialect/SCF/Transforms/Passes.h      |   6 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  11 +
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  41 ++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 145 ++----------
 .../SCF/Transforms/LoopSpecialization.cpp     | 221 ++++++++++++++++++
 .../Dialect/SCF/loop-continuous-peel.mlir     | 134 ++++-------
 7 files changed, 336 insertions(+), 227 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index e3d79a7f0ae40f3..10602271a9aa52c 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -167,9 +167,10 @@ def LoopContinuousPeelOp : Op<Transform_Dialect, "loop.loop_continuous_peel",
 
   let arguments =
       (ins TransformHandleTypeInterface:$target,
-           DefaultValuedAttr<BoolAttr, "false">:$single_iter_opt);
+           DefaultValuedAttr<BoolAttr, "false">:$convert_single_iter_loops_to_if);
   // TODO: Return both the peeled loop and the remainder loop.
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs TransformHandleTypeInterface:$peeled_loop,
+                      TransformHandleTypeInterface:$remainder_loop);
 
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfdb..9ff1d6a07f17c34 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -31,6 +31,12 @@ std::unique_ptr<Pass> createForLoopSpecializationPass();
 /// better vectorization.
 std::unique_ptr<Pass> createForLoopPeelingPass();
 
+/// Creates a pass that rewrites scf.for loops into a chain of loops whose
+/// steps are powers of 2 that decrease exponentially along the chain. Each
+/// loop executes only full steps and leaves the leftover iterations to the
+/// next, smaller-step loop, so no loop in the chain runs a partial step.
+std::unique_ptr<Pass> createForLoopContinuousPeelingPass();
+
 /// Creates a pass that canonicalizes affine.min and affine.max operations
 /// inside of scf.for loops with known lower and upper bounds.
 std::unique_ptr<Pass> createSCFForLoopCanonicalizationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index bbc673f44977ac9..daafb78a9134ccf 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -40,6 +40,17 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> {
   let dependentDialects = ["affine::AffineDialect"];
 }
 
+def SCFForLoopContinuousPeeling : Pass<"scf-for-loop-continuous-peeling"> {
+  let summary = "Convert a loop into a chain of loops with exponentially decreasing steps that are powers of 2";
+  let constructor = "mlir::createForLoopContinuousPeelingPass()";
+  let options = [
+    Option<"convertSingleIterLoopsToIf", "convert-single-iter-loops-to-if", "bool",
+           /*default=*/"false",
+           "Convert single-iteration loops to scf.if">
+  ];
+  let dependentDialects = ["affine::AffineDialect", "arith::ArithDialect"];
+}
+
 def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";
   let constructor = "mlir::createForLoopSpecializationPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 347beb9e4c64f8c..8a69b2076777312 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -81,6 +81,47 @@ void naivelyFuseParallelOps(Region &region);
 LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
                                            scf::ForOp &partialIteration);
 
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops.
+///
+/// E.g., assuming a lower bound of 0, the following loop
+/// ```
+/// scf.for %iv = %c0 to %ub step %c8 {
+///   (loop body)
+/// }
+/// ```
+/// is rewritten into the following pseudo IR:
+/// ```
+/// %newUb = %ub - (%ub mod %c8)
+/// scf.for %iv = %c0 to %newUb step %c8 {
+///   (loop body)
+/// }
+/// %newUb2 = %ub - (%ub mod %c4)
+/// scf.for %iv2 = %newUb to %newUb2 step %c4 {
+///   (loop body)
+/// }
+/// %newUb3 = %ub - (%ub mod %c2)
+/// scf.for %iv3 = %newUb2 to %newUb3 step %c2 {
+///   (loop body)
+/// }
+/// scf.for %iv4 = %newUb3 to %ub step %c1 {
+///   (loop body)
+/// }
+/// ```
+///
+/// Every loop after the first runs at most once; if
+/// `convertSingleIterLoopsToIf` is true, those loops are rewritten to `scf.if`.
+/// As in loop peeling, affine.min/affine.max ops in each resulting loop body
+/// are simplified for better canonicalization opportunities.
+/// `partialIteration` receives a peeled-off loop (NOTE(review): document
+/// which one). The return value indicates if the loop was rewritten; the
+/// loop is not rewritten if the step size is 1 or dynamic.
+LogicalResult
+continuousPeelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp,
+                                       scf::ForOp &partialIteration,
+                                       bool convertSingleIterLoopsToIf);
+
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index dcba6a8b406b21f..45b91f0d9b1fcf4 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -210,146 +210,27 @@ transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
 // LoopContinuousPeelOp
 //===---------------------------------------------------------------------===//
 
-static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
-                                     scf::ForOp &partialIteration,
-                                     Value &splitBound) {
-  RewriterBase::InsertionGuard guard(b);
-  auto lbInt = getConstantIntValue(forOp.getLowerBound());
-  auto ubInt = getConstantIntValue(forOp.getUpperBound());
-  auto stepInt = getConstantIntValue(forOp.getStep());
-
-  // No specialization necessary if step already divides upper bound evenly.
-  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
-    return failure();
-  // No specialization necessary if step size is 1.
-  if (stepInt == static_cast<int64_t>(1))
-    return failure();
-
-  // Create ForOp for partial iteration.
-  b.setInsertionPointAfter(forOp);
-  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
-  partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
-  partialIteration.getInitArgsMutable().assign(forOp->getResults());
-
-  // Set new upper loop bound.
-  b.updateRootInPlace(
-      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
-
-  return success();
-}
-
-static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
-  Location loc = forOp->getLoc();
-  IRMapping mapping;
-  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
-  for (auto [arg, operand] :
-       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
-    mapping.map(arg, operand.get());
-  }
-  b.setInsertionPoint(forOp);
-  auto cond =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                              forOp.getLowerBound(), forOp.getUpperBound());
-  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
-  // then branch
-  b.setInsertionPointToStart(ifOp.thenBlock());
-  for (Operation &op : forOp.getBody()->getOperations()) {
-    b.clone(op, mapping);
-  }
-  // else branch
-  b.setInsertionPointToStart(ifOp.elseBlock());
-  if (!forOp->getResultTypes().empty()) {
-    b.create<scf::YieldOp>(loc, forOp.getInits());
-  }
-  b.replaceOp(forOp, ifOp->getResults());
-  return ifOp;
-}
-
 DiagnosedSilenceableFailure transform::LoopContinuousPeelOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  scf::ForOp loop, currentLoop, partialLoop;
-  loop = dyn_cast<scf::ForOp>(target);
+  // The op may be applied to arbitrary payload: report non-scf.for targets as
+  // a silenceable failure instead of dereferencing a null loop below.
+  auto loop = dyn_cast<scf::ForOp>(target);
+  if (!loop)
+    return emitDefaultSilenceableFailure(target);
-  auto lbInt = getConstantIntValue(loop.getLowerBound());
-  auto stepInt = getConstantIntValue(loop.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return DiagnosedSilenceableFailure::
-        definiteFailure(); // step size must be a known positive constant
-  Value initialUb = loop.getUpperBound();
-  Value initialStep = loop.getStep();
-  uint64_t loopStep = *stepInt;
-  currentLoop = loop;
-  AffineExpr sym0, sym1, sym2;
-  bindSymbols(rewriter.getContext(), sym0, sym1, sym2);
-  AffineMap defaultSplitMap =
-      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
-  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
-  bool usePowerSplit = (lbInt.has_value()) &&
-                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
-                       (loopStep == llvm::bit_floor(loopStep));
-  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
-  SmallVector<scf::ForOp> loops;
-  while (loopStep) {
-    rewriter.setInsertionPoint(currentLoop);
-    auto constStepOp =
-        rewriter.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
-    currentLoop.getStepMutable().assign(constStepOp);
-    rewriter.setInsertionPoint(currentLoop);
-    Value splitBound = rewriter.createOrFold<affine::AffineApplyOp>(
-        currentLoop.getLoc(), splitMap,
-        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
-                   currentLoop.getStep()});
-    LogicalResult status =
-        splitLoopHelper(rewriter, currentLoop, partialLoop, splitBound);
-
-    // Canonicalize min/max affine operations
-    // It uses scf::rewritePeeledMinMaxOp to identify operations to be replaced,
-    // they are then replaced by the current step size.
-    // TODO: Alternative method - update affine map to reflect the loop step
-    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
-    currentLoop.walk([&](affine::AffineMinOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMinOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });
-    currentLoop.walk([&](affine::AffineMaxOp affineOp) {
-      rewriter.setInsertionPoint(affineOp);
-      auto clonedOp = cast<affine::AffineMaxOp>(rewriter.clone(*affineOp));
-      LogicalResult result = scf::rewritePeeledMinMaxOp(
-          rewriter, clonedOp, currentLoop.getInductionVar(), initialUb,
-          initialStep,
-          /*insideLoop=*/true);
-      if (result.succeeded())
-        rewriter.replaceOp(affineOp, currentLoop.getStep());
-      else
-        rewriter.eraseOp(clonedOp); // to avoid infinite walk
-    });

-    // Prepare for the next iteration
-    loops.push_back(currentLoop);
-    if (failed(status))
-      break;
-    currentLoop = partialLoop;
-    uint64_t maxPower = llvm::bit_floor(loopStep);
-    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
-  }
-  assert(loops.size() > 0 && "There should be at least one loop available");
-  if (getSingleIterOpt()) {
-    for (size_t i = 1; i < loops.size(); ++i) {
-      convertSingleIterFor(rewriter, loops[i]);
-    }
+  // `remainder` receives the partial-iteration loop reported by the helper;
+  // it is returned as the second handle.
+  scf::ForOp remainder;
+  LogicalResult status = scf::continuousPeelForLoopAndSimplifyBounds(
+      rewriter, loop, remainder, getConvertSingleIterLoopsToIf());
+  if (failed(status)) {
+    return emitSilenceableError() << "failed to perform continuous peeling";
   }

-  results.push_back(loops.front());
+  results.push_back(loop);
+  results.push_back(remainder);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f208e5245977d83..e2bc0e410878d47 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -28,6 +28,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFFORLOOPPEELING
+#define GEN_PASS_DEF_SCFFORLOOPCONTINUOUSPEELING
 #define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
 #define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -105,6 +106,165 @@ static void specializeForLoopForUnrolling(ForOp op) {
   op.erase();
 }
 
+static LogicalResult splitLoopHelper(RewriterBase &b, scf::ForOp &forOp,
+                                     scf::ForOp &partialIteration,
+                                     Value &splitBound) {
+  RewriterBase::InsertionGuard guard(b); // Restore caller's insertion point.
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto ubInt = getConstantIntValue(forOp.getUpperBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  // Split forOp at splitBound; fail, changing nothing, if no split is needed.
+  // No split necessary if step already divides the range [lb, ub) evenly.
+  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+    return failure();
+  // No split necessary if step size is 1.
+  if (stepInt == static_cast<int64_t>(1))
+    return failure();
+  // Clone forOp after itself as the partial iteration over [splitBound, ub);
+  // it replaces all uses of forOp's results and takes them as init args.
+  b.setInsertionPointAfter(forOp);
+  partialIteration = cast<scf::ForOp>(b.clone(*forOp.getOperation()));
+  partialIteration.getLowerBoundMutable().assign(splitBound);
+  forOp.replaceAllUsesWith(partialIteration->getResults());
+  partialIteration.getInitArgsMutable().assign(forOp->getResults());
+
+  // Shrink the original loop to [lb, splitBound).
+  b.updateRootInPlace(
+      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
+
+  return success();
+}
+
+/// Replaces `forOp`, assumed to run at most once, by an scf.if on lb < ub.
+static scf::IfOp convertSingleIterFor(RewriterBase &b, scf::ForOp &forOp) {
+  Location loc = forOp->getLoc();
+  IRMapping mapping;
+  mapping.map(forOp.getInductionVar(), forOp.getLowerBound());
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable()))
+    mapping.map(arg, operand.get());
+  b.setInsertionPoint(forOp);
+  auto cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                              forOp.getLowerBound(), forOp.getUpperBound());
+  auto ifOp = b.create<scf::IfOp>(loc, forOp->getResultTypes(), cond, true);
+  // Then branch: the loop body. A result-less scf.if is built with a yield.
+  b.setInsertionPointToStart(ifOp.thenBlock());
+  for (Operation &op : forOp.getBody()->without_terminator())
+    b.clone(op, mapping);
+  if (!forOp->getResultTypes().empty())
+    b.clone(*forOp.getBody()->getTerminator(), mapping);
+  // Else branch (zero iterations): forward the init args unchanged.
+  b.setInsertionPointToStart(ifOp.elseBlock());
+  if (!forOp->getResultTypes().empty())
+    b.create<scf::YieldOp>(loc, forOp.getInits());
+  b.replaceOp(forOp, ifOp->getResults());
+  return ifOp;
+}
+
+/// Rewrite a for loop with bounds/step that potentially do not divide the
+/// iteration space evenly into a chain of for loops where the step is a
+/// power of 2 and decreases exponentially across subsequent loops. Inside
+/// each loop, affine.min/max ops that scf::rewritePeeledMinMaxOp can simplify
+/// are replaced by that loop's step. Fails (IR untouched) on unsupported steps.
+///
+/// `partialIteration` receives the last peeled loop, if any. With
+/// `convertSingleIterLoopsToIf` (convert-single-iter-loops-to-if pass option)
+/// the loops after the first become scf.if ops; `partialIteration` is nulled.
+static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
+                                           ForOp &partialIteration,
+                                           bool convertSingleIterLoopsToIf) {
+
+  scf::ForOp currentLoop;
+  auto lbInt = getConstantIntValue(forOp.getLowerBound());
+  auto stepInt = getConstantIntValue(forOp.getStep());
+  if (!stepInt.has_value() || *stepInt <= 0)
+    return failure(); // step size must be a known positive constant
+  Value initialUb = forOp.getUpperBound();
+  Value initialStep = forOp.getStep();
+  uint64_t loopStep = *stepInt;
+  currentLoop = forOp;
+  // Split bound of each loop, as a map of (lb, ub, step): in general
+  // ub - (ub - lb) mod step; ub - ub mod step when lb is a multiple of a
+  // power-of-2 step, which keeps every later bound aligned as well.
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(b.getContext(), sym0, sym1, sym2);
+  AffineMap defaultSplitMap =
+      AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+  AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+  bool usePowerSplit = (lbInt.has_value()) &&
+                       (*lbInt % *stepInt == static_cast<int64_t>(0)) &&
+                       (loopStep == llvm::bit_floor(loopStep));
+  AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+  // The loops of the chain, in the order they appear in the IR.
+  SmallVector<scf::ForOp> loops;
+  while (loopStep) {
+    // Set the current step, then split off the range it does not divide.
+    b.setInsertionPoint(currentLoop);
+    auto constStepOp =
+        b.create<arith::ConstantIndexOp>(currentLoop.getLoc(), loopStep);
+    b.updateRootInPlace(currentLoop, [&]() {
+      currentLoop.getStepMutable().assign(constStepOp);
+    });
+    Value splitBound = b.createOrFold<affine::AffineApplyOp>(
+        currentLoop.getLoc(), splitMap,
+        ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+                   currentLoop.getStep()});
+    LogicalResult status =
+        splitLoopHelper(b, currentLoop, partialIteration, splitBound);
+
+    // Canonicalize affine.min/max ops in the loop: those that
+    // scf::rewritePeeledMinMaxOp can simplify are replaced by the loop step.
+    // TODO: Alternative method - update affine map to reflect the loop step
+    // Example: min(ub - iv, 8) -> min(ub - iv, 4)
+    auto rewriteMinMax = [&](auto affineOp) {
+      b.setInsertionPoint(affineOp);
+      auto clonedOp = cast<decltype(affineOp)>(b.clone(*affineOp));
+      LogicalResult result = scf::rewritePeeledMinMaxOp(
+          b, clonedOp, currentLoop.getInductionVar(), initialUb, initialStep,
+          /*insideLoop=*/true);
+      if (result.succeeded())
+        b.replaceOp(affineOp, currentLoop.getStep());
+      else
+        b.eraseOp(clonedOp); // Drop the clone that could not be simplified.
+    };
+    currentLoop.walk([&](affine::AffineMinOp op) { rewriteMinMax(op); });
+    currentLoop.walk([&](affine::AffineMaxOp op) { rewriteMinMax(op); });
+
+    // Stop once no split was needed; otherwise continue on the peeled-off
+    // remainder with the largest power of 2 strictly below the current step.
+    loops.push_back(currentLoop);
+    if (failed(status))
+      break;
+    currentLoop = partialIteration;
+    uint64_t maxPower = llvm::bit_floor(loopStep);
+    loopStep = maxPower == loopStep ? maxPower >> 1 : maxPower;
+  }
+
+  assert(loops.size() > 0 && "There should be at least one loop available");
+  // Every loop after the first is expected to run at most once.
+  if (convertSingleIterLoopsToIf && loops.size() > 1) {
+    for (scf::ForOp loop : llvm::drop_begin(loops))
+      convertSingleIterFor(b, loop);
+    // partialIteration is loops.back(), which was just erased: do not hand a
+    // dangling op back to the caller.
+    partialIteration = ForOp();
+  }
+
+  return success();
+}
+
+LogicalResult mlir::scf::continuousPeelForLoopAndSimplifyBounds(
+    RewriterBase &rewriter, ForOp forOp, ForOp &partialIteration,
+    bool convertSingleIterLoopsToIf) {
+  // continuousPeelForLoop does both the peeling and the min/max rewriting.
+  if (failed(continuousPeelForLoop(rewriter, forOp, partialIteration,
+                                   convertSingleIterLoopsToIf)))
+    return failure();
+
+  return success();
+}
+
 /// Rewrite a for loop with bounds/step that potentially do not divide evenly
 /// into a for loop where the step divides the iteration space evenly, followed
 /// by an scf.if for the last (partial) iteration (if any).
@@ -241,6 +401,45 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
 };
 } // namespace
 
+namespace {
+struct ForLoopContinuousPeelingPattern : public OpRewritePattern<ForOp> {
+  ForLoopContinuousPeelingPattern(MLIRContext *ctx,
+                                  bool convertSingleIterLoopsToIf)
+      : OpRewritePattern<ForOp>(ctx),
+        convertSingleIterLoopsToIf(convertSingleIterLoopsToIf) {}
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Do not peel already peeled loops.
+    if (forOp->hasAttr(kPeeledLoopLabel))
+      return failure();
+
+    // Apply continuous loop peeling.
+    scf::ForOp partialIteration;
+    if (failed(continuousPeelForLoopAndSimplifyBounds(
+            rewriter, forOp, partialIteration, convertSingleIterLoopsToIf)))
+      return failure();
+    // partialIteration is null when nothing was peeled and is erased when the
+    // peeled loops were turned into scf.if ops; label it only if still alive.
+    if (partialIteration && !convertSingleIterLoopsToIf) {
+      rewriter.updateRootInPlace(partialIteration, [&]() {
+        partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+        partialIteration->setAttr(kPartialIterationLabel,
+                                  rewriter.getUnitAttr());
+      });
+    }
+    rewriter.updateRootInPlace(forOp, [&]() {
+      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    });
+    return success();
+  }
+
+  /// If set to true, every loop of the peeled chain except the first is
+  /// rewritten into an scf.if (forwarded to the peeling utility).
+  bool convertSingleIterLoopsToIf;
+};
+} // namespace
+
 namespace {
 struct ParallelLoopSpecialization
     : public impl::SCFParallelLoopSpecializationBase<
@@ -273,6 +472,24 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
     });
   }
 };
+
+struct ForLoopContinuousPeeling
+    : public impl::SCFForLoopContinuousPeelingBase<ForLoopContinuousPeeling> {
+  void runOnOperation() override { // Apply the peeling pattern greedily.
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopContinuousPeelingPattern>(ctx,
+                                                  convertSingleIterLoopsToIf);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
+    // The pattern uses kPeeledLoopLabel to avoid re-peeling a loop.
+    // Drop the markers so that they do not leak into the output IR.
+    parentOp->walk([](Operation *op) {
+      op->removeAttr(kPeeledLoopLabel);
+      op->removeAttr(kPartialIterationLabel);
+    });
+  }
+};
 } // namespace
 
 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
@@ -286,3 +503,7 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
 std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
   return std::make_unique<ForLoopPeeling>();
 }
+
+std::unique_ptr<Pass> mlir::createForLoopContinuousPeelingPass() {
+  return std::make_unique<ForLoopContinuousPeeling>(); // Default options.
+}
diff --git a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
index 752e1b1efed92ac..e051e6a43be70ea 100644
--- a/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
+++ b/mlir/test/Dialect/SCF/loop-continuous-peel.mlir
@@ -1,98 +1,48 @@
-// RUN: mlir-opt %s --transform-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -scf-for-loop-continuous-peeling=convert-single-iter-loops-to-if=true -split-input-file | FileCheck %s
+// Peel a step-8 loop into a step-8 scf.for followed by scf.if ops for steps
+// 4, 2 and 1; the affine.min in each body is replaced by that step.
 
-#map = affine_map<(d0) -> ()>
-#map1 = affine_map<(d0) -> (d0)>
-module {
-  func.func @foo(%arg0: f32, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : f32) outs(%arg1 : tensor<?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %3 = arith.mulf %in, %out : f32
-      linalg.yield %3 : f32
-    } -> tensor<?xf32>
-    return %0 : tensor<?xf32>
-  }
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loops = transform.structured.tile_using_for %0[8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %2 = transform.cast %loops : !transform.any_op to !transform.op<"scf.for">
-    %3 = transform.loop.loop_continuous_peel %2 {single_iter_opt = true} : (!transform.op<"scf.for">) -> (!transform.any_op)
-    transform.yield
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func.func @foo(%ub: index) -> index {
+  %c0 = arith.constant 0 : index
+  %step = arith.constant 8 : index
+  %0 = scf.for %iv = %c0 to %ub step %step iter_args(%arg = %c0) -> (index) {
+    %1 = affine.min #map(%ub, %iv)[%step]
+    %2 = index.add %1, %arg
+    scf.yield %2 : index
   }
+  return %0 : index
 }
 
 // CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2] -> (s1 - s1 mod s2)>
-// CHECK: #[[MAP1:.*]] = affine_map<() -> (8)>
-// CHECK: #[[MAP2:.*]] = affine_map<(d0) -> (d0 - 1)>
-// CHECK: #[[MAP3:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[MAP4:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK: func.func @foo(%[[S:.*]]: f32, %[[INVEC1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[DIM:.*]] = tensor.dim %[[INVEC1]], %[[C0]] : tensor<?xf32>
-// CHECK:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %{{.*}} = arith.constant 8 : index
-// CHECK:       %[[C8:.*]] = arith.constant 8 : index
-// CHECK:       %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[C0]], %[[DIM]], %[[C8]]]
-// CHECK:       %[[INS1:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[IDX0]] step %[[C8]] iter_args(%[[AINVEC1:.*]] = %[[INVEC1]]) -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C8]])
-// CHECK:         %[[XS8:.*]] = tensor.extract_slice %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%{{.*}} : f32) outs(%[[XS8]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[AINVEC1]][%[[IDX]]] [%[[C8]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C4:.*]] = arith.constant 4 : index
-// CHECK:       %[[IDX2:.*]] = affine.apply #[[MAP]]()[%[[IDX0]], %[[DIM]], %[[C4]]]
-// CHECK:       %[[CMP3:.*]] = arith.cmpi slt, %[[IDX0]], %[[IDX2]] : index
-// CHECK:       %[[INS2:.*]] = scf.if %[[CMP3]] -> (tensor<?xf32>) {
-// CHECK:          %{{.*}} = affine.apply #[[MAP2]](%[[C4]])
-// CHECK:         %[[XS4:.*]] = tensor.extract_slice %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS4]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf  %{{.*}},  %{{.*}} : f32
-// CHECK:           linalg.yield  %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS1]][%[[IDX0]]] [%[[C4]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS1]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK:       %[[IDX3:.*]] = affine.apply #[[MAP]]()[%[[IDX2]], %[[DIM]], %[[C2]]]
-// CHECK:       %[[CMP4:.*]] = arith.cmpi slt, %[[IDX2]], %[[IDX3]] : index
-// CHECK:       %[[INS3:.*]] = scf.if %[[CMP4]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C2]])
-// CHECK:         %[[XS2:.*]] = tensor.extract_slice %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS2]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS2]][%[[IDX2]]] [%[[C2]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS2]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK:       %{{.*}} = affine.apply #[[MAP]]()[%[[IDX3]], %[[DIM]], %[[C1]]]
-// CHECK:       %[[CMP5:.*]] = arith.cmpi slt, %[[IDX3]], %[[DIM]] : index
-// CHECK:       %[[INS4:.*]] = scf.if %[[CMP5]] -> (tensor<?xf32>) {
-// CHECK:         %{{.*}} = affine.apply #[[MAP2]](%[[C1]])
-// CHECK:         %[[XS1:.*]] = tensor.extract_slice %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> to tensor<?xf32>
-// CHECK:         %[[MUL:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel"]} ins(%[[S]] : f32) outs(%[[XS1]] : tensor<?xf32>) {
-// CHECK:         ^bb0(%{{.*}}: f32, %{{.*}}: f32):
-// CHECK:           %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } -> tensor<?xf32>
-// CHECK:         %[[INS:.*]] = tensor.insert_slice %[[MUL]] into %[[INS3]][%[[IDX3]]] [%[[C1]]] [1] : tensor<?xf32> into tensor<?xf32>
-// CHECK:         scf.yield %[[INS]] : tensor<?xf32>
-// CHECK:       } else {
-// CHECK:         scf.yield %[[INS3]] : tensor<?xf32>
-// CHECK:       }
-// CHECK:       return %[[INS4]] : tensor<?xf32>
+// CHECK: func.func @foo(%[[UB:.*]]: index) -> index {
+// CHECK: %[[STEP8:.*]] = arith.constant 8 : index
+// CHECK: %[[STEP4:.*]] = arith.constant 4 : index
+// CHECK: %[[STEP2:.*]] = arith.constant 2 : index
+// CHECK: %[[STEP1:.*]] = arith.constant 1 : index
+// CHECK: %[[LB:.*]] = arith.constant 0 : index
+// CHECK: %[[I0:.*]] = affine.apply #[[MAP]]()[%[[LB]], %[[UB]], %[[STEP8]]]
+// CHECK: %[[I1:.*]] = scf.for %{{.*}} = %[[LB]] to %[[I0]] step %[[STEP8]] iter_args(%[[ALB:.*]] = %[[LB]]) -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[ALB]], %[[STEP8]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: %[[I2:.*]] = affine.apply #[[MAP]]()[%[[I0]], %[[UB]], %[[STEP4]]]
+// CHECK: %[[I3:.*]] = arith.cmpi slt, %[[I0]], %[[I2]] : index
+// CHECK: %[[I4:.*]] = scf.if %[[I3]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I1]], %[[STEP4]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I1]] : index
+// CHECK: %[[I5:.*]] = affine.apply #[[MAP]]()[%[[I2]], %[[UB]], %[[STEP2]]]
+// CHECK: %[[I6:.*]] = arith.cmpi slt, %[[I2]], %[[I5]] : index
+// CHECK: %[[I7:.*]] = scf.if %[[I6]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I4]], %[[STEP2]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I4]] : index
+// CHECK: %[[I8:.*]] = arith.cmpi slt, %[[I5]], %[[UB]] : index
+// CHECK: %[[I9:.*]] = scf.if %[[I8]] -> (index) {
+// CHECK: %[[SUM:.*]] = index.add %[[I7]], %[[STEP1]]
+// CHECK: scf.yield %[[SUM]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[I7]] : index
+// CHECK: return %[[I9]] : index

>From 3a7d8b0c73d06a4b07ab654d69693af03a7f8e16 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 7 Nov 2023 23:52:17 +0800
Subject: [PATCH 3/3] [MLIR][LLVM][Fixes] Add Continuous Loop Peeling transform
 to SCF

Bail out early when the step size is 1, as there is nothing to peel.

This commit should be squashed into the original.
---
 mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index e2bc0e410878d47..1c32b483d0d1928 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -178,8 +178,11 @@ static LogicalResult continuousPeelForLoop(RewriterBase &b, ForOp forOp,
   scf::ForOp currentLoop;
   auto lbInt = getConstantIntValue(forOp.getLowerBound());
   auto stepInt = getConstantIntValue(forOp.getStep());
-  if (!stepInt.has_value() || *stepInt <= 0)
-    return failure(); // step size must be a known positive constant
+
+  // Step size must be a known constant greater than 1.
+  if (!stepInt || *stepInt <= 1)
+    return failure();
+
   Value initialUb = forOp.getUpperBound();
   Value initialStep = forOp.getStep();
   uint64_t loopStep = *stepInt;



More information about the Mlir-commits mailing list