[Mlir-commits] [mlir] [mlir][SCF] Modernize `coalesceLoops` method to handle `scf.for` loops with iter_args (PR #87019)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 28 16:17:29 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/87019

As part of this extension this change also does some general cleanup

1) Make all the methods take `RewriterBase` as arguments instead of
   creating their own builders that tend to crash when used within
   pattern rewrites
2) For the induction variables being `index` types use
   `makeComposedFoldedAffineApply` to constant propagate where
   possible. The non-index types can't use this path, so they continue
   to generate the arith instructions
3) Split `coalescePerfectlyNestedLoops` into two separate methods, one
   for `scf.for` and the other for `affine.for`. The templatization
   didn't seem to be buying much there.
4) Also add a canonicalization to `affine.delinearize_index` to drop
   the delinearization when the outer dimensions are `1` and replace
   those with `0`.

Also general clean up of tests.

>From ce943c01a17a1d26ef3f7470f5a49234ed4d90cf Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Tue, 12 Mar 2024 22:43:11 -0700
Subject: [PATCH] [mlir][SCF] Modernize `coalesceLoops` method to handle
 `scf.for` loops with iter_args.

As part of this extension this change also does some general cleanup

1) Make all the methods take `RewriterBase` as arguments instead of
   creating their own builders that tend to crash when used within
   pattern rewrites
2) For the induction variables being `index` types use
   `makeComposedFoldedAffineApply` to constant propagate where
   possible. The non-index types can't use this path, so they continue
   to generate the arith instructions
3) Split `coalescePerfectlyNestedLoops` into two separate methods, one
   for `scf.for` and the other for `affine.for`. The templatization
   didn't seem to be buying much there.
4) Also add a canonicalization to `affine.delinearize_index` to drop
   the delinearization when the outer dimensions are `1` and replace
   those with `0`.

Also general clean up of tests.
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |   1 +
 mlir/include/mlir/Dialect/Affine/LoopUtils.h  |  49 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |   7 +-
 mlir/include/mlir/IR/PatternMatch.h           |   9 +
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  36 ++
 .../Affine/Transforms/LoopCoalescing.cpp      |   8 +-
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   |  46 ++
 .../SCF/TransformOps/SCFTransformOps.cpp      |   4 +-
 .../SCF/Transforms/ParallelLoopCollapsing.cpp |   9 +-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 421 +++++++++++++-----
 mlir/test/Dialect/Affine/canonicalize.mlir    |  16 +
 mlir/test/Dialect/Affine/loop-coalescing.mlir | 200 ++++-----
 .../Dialect/SCF/transform-op-coalesce.mlir    | 166 ++++++-
 .../Transforms/parallel-loop-collapsing.mlir  |  30 +-
 .../single-parallel-loop-collapsing.mlir      |  29 +-
 15 files changed, 717 insertions(+), 314 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index edcfcfd830c443..a0b14614934519 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1095,6 +1095,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   ];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc5..d143954b78fc12 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -299,53 +299,8 @@ LogicalResult
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
-/// Walk either an scf.for or an affine.for to find a band to coalesce.
-template <typename LoopOpTy>
-LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
-  LogicalResult result(failure());
-  SmallVector<LoopOpTy> loops;
-  getPerfectlyNestedLoops(loops, op);
-
-  // Look for a band of loops that can be coalesced, i.e. perfectly nested
-  // loops with bounds defined above some loop.
-  // 1. For each loop, find above which parent loop its operands are
-  // defined.
-  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
-  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
-    operandsDefinedAbove[i] = i;
-    for (unsigned j = 0; j < i; ++j) {
-      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
-        operandsDefinedAbove[i] = j;
-        break;
-      }
-    }
-  }
-
-  // 2. Identify bands of loops such that the operands of all of them are
-  // defined above the first loop in the band.  Traverse the nest bottom-up
-  // so that modifications don't invalidate the inner loops.
-  for (unsigned end = loops.size(); end > 0; --end) {
-    unsigned start = 0;
-    for (; start < end - 1; ++start) {
-      auto maxPos =
-          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
-                            std::next(operandsDefinedAbove.begin(), end));
-      if (maxPos > start)
-        continue;
-      assert(maxPos == start &&
-             "expected loop bounds to be known at the start of the band");
-      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
-      if (succeeded(coalesceLoops(band)))
-        result = success();
-      break;
-    }
-    // If a band was found and transformed, keep looking at the loops above
-    // the outermost transformed loop.
-    if (start != end - 1)
-      end = start + 1;
-  }
-  return result;
-}
+/// Walk an affine.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
 
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 883d11bcc4df06..bc09cc7f7fa5e0 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -100,11 +100,16 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
 /// `loops` contains a list of perfectly nested loops with bounds and steps
 /// independent of any loop induction variable involved in the nest.
 LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
+LogicalResult coalesceLoops(RewriterBase &rewriter,
+                            MutableArrayRef<scf::ForOp>);
+
+/// Walk an scf.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
 
 /// Take the ParallelLoop and for each set of dimension indices, combine them
 /// into a single dimension. combinedDimensions must contain each index into
 /// loops exactly once.
-void collapseParallelLoops(scf::ParallelOp loops,
+void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
 /// Unrolls this for operation by the specified unroll factor. Returns failure
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 070e6ed702f86a..fabe4cc401cff5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/TypeName.h"
 #include <optional>
 
@@ -697,6 +698,14 @@ class RewriterBase : public OpBuilder {
       return user != exceptedUser;
     });
   }
+  void
+  replaceAllUsesExcept(Value from, Value to,
+                       const SmallPtrSetImpl<Operation *> &preservedUsers) {
+    return replaceUsesWithIf(from, to, [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return !preservedUsers.contains(user);
+    });
+  }
 
   /// Used to notify the listener that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c591e5056480ca..4837ced453fa43 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4532,6 +4532,42 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+namespace {
+// When the outer dimensions used for delinearization are one, the
+// corresponding results can all be replaced by zeros.
+struct DropUnitOuterDelinearizeDims
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp indexOp,
+                                PatternRewriter &rewriter) const override {
+    ValueRange basis = indexOp.getBasis();
+    if (basis.empty()) {
+      return failure();
+    }
+    std::optional<int64_t> basisValue =
+        getConstantIntValue(getAsOpFoldResult(basis.front()));
+    if (!basisValue || basisValue != 1) {
+      return failure();
+    }
+    SmallVector<Value> replacements;
+    Location loc = indexOp.getLoc();
+    replacements.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    auto newIndexOp = rewriter.create<AffineDelinearizeIndexOp>(
+        loc, indexOp.getLinearIndex(), basis.drop_front());
+    replacements.append(newIndexOp->result_begin(), newIndexOp->result_end());
+    rewriter.replaceOp(indexOp, replacements);
+    return success();
+  }
+};
+
+} // namespace
+
+void AffineDelinearizeIndexOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<DropUnitOuterDelinearizeDims>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 1dc69ab493d477..1f23055544d2a5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -35,13 +35,17 @@ namespace {
 struct LoopCoalescingPass
     : public affine::impl::LoopCoalescingBase<LoopCoalescingPass> {
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect>();
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     func.walk<WalkOrder::PreOrder>([](Operation *op) {
       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
-        (void)coalescePerfectlyNestedLoops(scfForOp);
+        (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
-        (void)coalescePerfectlyNestedLoops(affineForOp);
+        (void)coalescePerfectlyNestedAffineLoops(affineForOp);
     });
   }
 };
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index af59973d7a92c5..0b2885e6396aae 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2765,3 +2765,49 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
 
   return success();
 }
+
+LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
+  LogicalResult result(failure());
+  SmallVector<AffineForOp> loops;
+  getPerfectlyNestedLoops(loops, op);
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+  // 1. For each loop, find above which parent loop its operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+  }
+
+  // 2. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band.  Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+      assert(maxPos == start &&
+             "expected loop bounds to be known at the start of the band");
+      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+      if (succeeded(coalesceLoops(band)))
+        result = success();
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+  return result;
+}
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c0918414820803..7e4faf8b73afbb 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
                                       transform::TransformState &state) {
   LogicalResult result(failure());
   if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
-    result = coalescePerfectlyNestedLoops(scfForOp);
+    result = coalescePerfectlyNestedSCFForLoops(scfForOp);
   else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
-    result = coalescePerfectlyNestedLoops(affineForOp);
+    result = coalescePerfectlyNestedAffineLoops(affineForOp);
 
   results.push_back(op);
   if (failed(result)) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index a69df025bcba81..ada0c971cb86bf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -28,6 +29,11 @@ namespace {
 struct TestSCFParallelLoopCollapsing
     : public impl::TestSCFParallelLoopCollapsingBase<
           TestSCFParallelLoopCollapsing> {
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect>();
+  }
+
   void runOnOperation() override {
     Operation *module = getOperation();
 
@@ -88,6 +94,7 @@ struct TestSCFParallelLoopCollapsing
     // Only apply the transformation on parallel loops where the specified
     // transformation is valid, but do NOT early abort in the case of invalid
     // loops.
+    IRRewriter rewriter(&getContext());
     module->walk([&](scf::ParallelOp op) {
       if (flattenedCombinedLoops.size() != op.getNumLoops()) {
         op.emitOpError("has ")
@@ -97,7 +104,7 @@ struct TestSCFParallelLoopCollapsing
             << flattenedCombinedLoops.size() << " iter args.";
         return;
       }
-      collapseParallelLoops(op, combinedLoops);
+      collapseParallelLoops(rewriter, op, combinedLoops);
     });
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 914aeb4fa79fda..ac42a21a883fa3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -12,7 +12,9 @@
 
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -472,18 +474,43 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
-                                OpBuilder &insideLoopBuilder, Location loc,
-                                Value lowerBound, Value upperBound, Value step,
-                                Value inductionVar) {
+/// Transform a loop with a strictly positive step
+///   for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+///     %i = %ii * %s + %lb
+/// This overload only computes and returns the normalized upper bound,
+/// i.e. ceildiv(%ub - %lb, %s), using `makeComposedFoldedAffineApply`
+/// so that constants are folded where possible.
+static OpFoldResult normalizeLoop(RewriterBase &rewriter, Location loc,
+                                  OpFoldResult lb, OpFoldResult ub,
+                                  OpFoldResult step) {
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr normalizeExpr = (s1 - s0).ceilDiv(s2);
+
+  OpFoldResult newUb = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, normalizeExpr, {lb, ub, step});
+  return newUb;
+}
+static LoopParams normalizeLoop(RewriterBase &rewriter, Location loc, Value lb,
+                                Value ub, Value step) {
+  auto isIndexType = [](Value v) { return v.getType().isa<IndexType>(); };
+  if (isIndexType(lb) && isIndexType(ub) && isIndexType(step)) {
+    OpFoldResult newUb =
+        normalizeLoop(rewriter, loc, getAsOpFoldResult(lb),
+                      getAsOpFoldResult(ub), getAsOpFoldResult(step));
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    return {zero, getValueOrCreateConstantIndexOp(rewriter, loc, newUb), one};
+  }
+  // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
   bool isZeroBased = false;
-  if (auto ubCst = getConstantIntValue(lowerBound))
-    isZeroBased = ubCst.value() == 0;
+  if (auto lbCst = getConstantIntValue(lb))
+    isZeroBased = lbCst.value() == 0;
 
   bool isStepOne = false;
   if (auto stepCst = getConstantIntValue(step))
@@ -493,62 +520,130 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
   // assuming the step is strictly positive.  Update the bounds and the step
   // of the loop to go from 0 to the number of iterations, if necessary.
   if (isZeroBased && isStepOne)
-    return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
-            /*step=*/step};
+    return {lb, ub, step};
 
-  Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+  Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value newUpperBound =
-      boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
-  Value newLowerBound =
-      isZeroBased ? lowerBound
-                  : boundsBuilder.create<arith::ConstantOp>(
-                        loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
-  Value newStep =
-      isStepOne ? step
-                : boundsBuilder.create<arith::ConstantOp>(
-                      loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
-  // Insert code computing the value of the original loop induction variable
-  // from the "normalized" one.
-  Value scaled =
-      isStepOne
-          ? inductionVar
-          : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
-  Value shifted =
-      isZeroBased
-          ? scaled
-          : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
-  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
-                                       shifted.getDefiningOp()};
-  inductionVar.replaceAllUsesExcept(shifted, preserve);
-  return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
-          /*step=*/newStep};
+      isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+
+  Value newLowerBound = isZeroBased
+                            ? lb
+                            : rewriter.create<arith::ConstantOp>(
+                                  loc, rewriter.getZeroAttr(lb.getType()));
+  Value newStep = isStepOne
+                      ? step
+                      : rewriter.create<arith::ConstantOp>(
+                            loc, rewriter.getIntegerAttr(step.getType(), 1));
+
+  return {newLowerBound, newUpperBound, newStep};
 }
 
-/// Transform a loop with a strictly positive step
-///   for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-///     %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
-  OpBuilder builder(outer);
-  OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
-  auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
-                                  loop.getLowerBound(), loop.getUpperBound(),
-                                  loop.getStep(), loop.getInductionVar());
-
-  loop.setLowerBound(loopPieces.lowerBound);
-  loop.setUpperBound(loopPieces.upperBound);
-  loop.setStep(loopPieces.step);
+/// Get back the original induction variable values after loop normalization
+static void unNormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+                                         Value normalizedIv, Value origLb,
+                                         Value origStep) {
+  Value unNormalizedIv;
+  std::optional<Operation *> preserve;
+  if (normalizedIv.getType().isa<IndexType>()) {
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    AffineExpr ivExpr = (s0 * s1) + s2;
+    OpFoldResult newIv = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, ivExpr,
+        ArrayRef<OpFoldResult>{normalizedIv, origStep, origLb});
+    unNormalizedIv = getValueOrCreateConstantIndexOp(rewriter, loc, newIv);
+    preserve = unNormalizedIv.getDefiningOp();
+  } else {
+    bool isStepOne = isConstantIntValue(origStep, 1);
+    bool isZeroBased = isConstantIntValue(origLb, 0);
+
+    Value scaled = normalizedIv;
+    if (!isStepOne) {
+      scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+      preserve = scaled.getDefiningOp();
+    }
+    unNormalizedIv = scaled;
+    if (!isZeroBased)
+      unNormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+  }
+
+  if (preserve) {
+    rewriter.replaceAllUsesExcept(normalizedIv, unNormalizedIv,
+                                  preserve.value());
+  } else {
+    rewriter.replaceAllUsesWith(normalizedIv, unNormalizedIv);
+  }
+  return;
 }
 
-LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+/// Helper function to multiply a sequence of values.
+static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
+                                        ArrayRef<OpFoldResult> values) {
+  AffineExpr s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  AffineExpr mulExpr = s0 * s1;
+  OpFoldResult productOf = rewriter.getIndexAttr(1);
+  for (auto v : values) {
+    productOf = affine::makeComposedFoldedAffineApply(rewriter, loc, mulExpr,
+                                                      {productOf, v});
+  }
+  return productOf;
+}
+static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
+                                       ArrayRef<Value> values) {
+  assert(!values.empty() && "unexpected empty list");
+  if (values.front().getType().isa<IndexType>()) {
+    return getValueOrCreateConstantIndexOp(
+        rewriter, loc,
+        getProductOfIndexes(rewriter, loc, getAsOpFoldResult(values)));
+  }
+  Value productOf = values.front();
+  for (auto v : values.drop_front()) {
+    productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+  }
+  return productOf;
+}
+
+// For each original loop, the value of the
+// induction variable can be obtained by dividing the induction variable of
+// the linearized loop by the total number of iterations of the loops nested
+// in it modulo the number of iterations in this loop (remove the values
+// related to the outer loops):
+//   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+// Compute these iteratively from the innermost loop by creating a "running
+// quotient" of division by the range.
+static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
+delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
+                             Value linearizedIv, ArrayRef<Value> ubs) {
+  auto isIndexType = [](Value v) { return v.getType().isa<IndexType>(); };
+  if (linearizedIv.getType().isa<IndexType>()) {
+    auto delinearizeIvs = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, linearizedIv, ubs);
+    return {llvm::map_to_vector(delinearizeIvs.getResults(),
+                                [](OpResult res) -> Value { return res; }),
+            {delinearizeIvs}};
+  }
+  Value previous = linearizedIv;
+  SmallVector<Value> delinearizedIvs(ubs.size());
+  SmallPtrSet<Operation *, 2> preservedUsers;
+  for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
+    unsigned idx = ubs.size() - i - 1;
+    if (i != 0) {
+      previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
+      preservedUsers.insert(previous.getDefiningOp());
+    }
+    Value iv = previous;
+    if (i != e - 1) {
+      iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+      preservedUsers.insert(iv.getDefiningOp());
+    }
+    delinearizedIvs[idx] = iv;
+  }
+  return {delinearizedIvs, preservedUsers};
+}
+
+LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
+                                  MutableArrayRef<scf::ForOp> loops) {
   if (loops.size() < 2)
     return failure();
 
@@ -557,57 +652,151 @@ LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
 
   // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
   // allows the following code to assume upperBound is the number of iterations.
-  for (auto loop : loops)
-    normalizeLoop(loop, outermost, innermost);
+  for (auto loop : loops) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(outermost);
+    Value lb = loop.getLowerBound();
+    Value ub = loop.getUpperBound();
+    Value step = loop.getStep();
+    auto newLoopParams = normalizeLoop(rewriter, loop.getLoc(), lb, ub, step);
+    loop.setLowerBound(newLoopParams.lowerBound);
+    loop.setUpperBound(newLoopParams.upperBound);
+    loop.setStep(newLoopParams.step);
+
+    rewriter.setInsertionPointToStart(innermost.getBody());
+    unNormalizeInductionVariable(rewriter, loop.getLoc(),
+                                 loop.getInductionVar(), lb, step);
+  }
 
   // 2. Emit code computing the upper bound of the coalesced loop as product
   // of the number of iterations of all loops.
-  OpBuilder builder(outermost);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(outermost);
   Location loc = outermost.getLoc();
-  Value upperBound = outermost.getUpperBound();
-  for (auto loop : loops.drop_front())
-    upperBound =
-        builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
+  SmallVector<Value> upperBounds = llvm::map_to_vector(
+      loops, [](auto loop) { return loop.getUpperBound(); });
+  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
   outermost.setUpperBound(upperBound);
 
-  builder.setInsertionPointToStart(outermost.getBody());
-
-  // 3. Remap induction variables. For each original loop, the value of the
-  // induction variable can be obtained by dividing the induction variable of
-  // the linearized loop by the total number of iterations of the loops nested
-  // in it modulo the number of iterations in this loop (remove the values
-  // related to the outer loops):
-  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
-  // Compute these iteratively from the innermost loop by creating a "running
-  // quotient" of division by the range.
-  Value previous = outermost.getInductionVar();
-  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
-    unsigned idx = loops.size() - i - 1;
-    if (i != 0)
-      previous = builder.create<arith::DivSIOp>(loc, previous,
-                                                loops[idx + 1].getUpperBound());
-
-    Value iv = (i == e - 1) ? previous
-                            : builder.create<arith::RemSIOp>(
-                                  loc, previous, loops[idx].getUpperBound());
-    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
-                               loops.back().getRegion());
-  }
-
-  // 4. Move the operations from the innermost just above the second-outermost
-  // loop, delete the extra terminator and the second-outermost loop.
-  scf::ForOp second = loops[1];
-  innermost.getBody()->back().erase();
-  outermost.getBody()->getOperations().splice(
-      Block::iterator(second.getOperation()),
-      innermost.getBody()->getOperations());
-  second.erase();
+  rewriter.setInsertionPointToStart(innermost.getBody());
+  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
+      rewriter, loc, outermost.getInductionVar(), upperBounds);
+  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
+                                preservedUsers);
+
+  for (int i = loops.size() - 1; i > 0; --i) {
+    auto outerLoop = loops[i - 1];
+    auto innerLoop = loops[i];
+
+    rewriter.replaceAllUsesWith(innerLoop.getInductionVar(), delinearizeIvs[i]);
+    for (auto [outerLoopIterArg, innerLoopIterArg] : llvm::zip_equal(
+             outerLoop.getRegionIterArgs(), innerLoop.getRegionIterArgs())) {
+      rewriter.replaceAllUsesExcept(innerLoopIterArg, outerLoopIterArg,
+                                    preservedUsers);
+    }
+    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
+    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+    rewriter.eraseOp(innerTerminator);
+    outerLoop.getBody()->getOperations().splice(
+        Block::iterator(innerLoop), innerLoop.getBody()->getOperations());
+    rewriter.replaceOp(innerLoop, yieldedVals);
+  }
   return success();
 }
 
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+  if (loops.empty()) {
+    return success();
+  }
+  IRRewriter rewriter(loops.front().getContext());
+  return coalesceLoops(rewriter, loops);
+}
+
+LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
+  LogicalResult result(failure());
+  SmallVector<scf::ForOp> loops;
+  getPerfectlyNestedLoops(loops, op);
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+
+  // 1. For each loop, find above which parent loop its bounds operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
+                                           loops[i].getUpperBound(),
+                                           loops[i].getStep()};
+      if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+  }
+
+  // 2. For each inner loop check that the iter_args for the immediately outer
+  // loop are the init for the immediately inner loop and that the yields of the
+  // return of the inner loop is the yield for the immediately outer loop. Keep
+  // track of where the chain starts from for each loop.
+  SmallVector<unsigned> iterArgChainStart(loops.size());
+  iterArgChainStart[0] = 0;
+  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
+    // By default set the start of the chain to itself.
+    iterArgChainStart[i] = i;
+    auto outerloop = loops[i - 1];
+    auto innerLoop = loops[i];
+    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
+      continue;
+    }
+    if (llvm::any_of(
+            llvm::zip_equal(outerloop.getRegionIterArgs(),
+                            innerLoop.getInitArgs()),
+            [](auto it) { return std::get<0>(it) != std::get<1>(it); })) {
+      continue;
+    }
+    auto outerloopTerminator = outerloop.getBody()->getTerminator();
+    if (llvm::any_of(
+            llvm::zip_equal(outerloopTerminator->getOperands(),
+                            innerLoop.getResults()),
+            [](auto it) { return std::get<0>(it) != std::get<1>(it); })) {
+      continue;
+    }
+    iterArgChainStart[i] = iterArgChainStart[i - 1];
+  }
+
+  // 3. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band.  Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+      if (iterArgChainStart[end - 1] > start)
+        continue;
+      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+      if (succeeded(coalesceLoops(band)))
+        result = success();
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+  return result;
+}
+
 void mlir::collapseParallelLoops(
-    scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
-  OpBuilder outsideBuilder(loops);
+    RewriterBase &rewriter, scf::ParallelOp loops,
+    ArrayRef<std::vector<unsigned>> combinedDimensions) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loops);
   Location loc = loops.getLoc();
 
   // Presort combined dimensions.
@@ -619,25 +808,29 @@ void mlir::collapseParallelLoops(
   SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
       normalizedUpperBounds;
   for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
-    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
-    auto resultBounds =
-        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
-                      loops.getLowerBound()[i], loops.getUpperBound()[i],
-                      loops.getStep()[i], loops.getBody()->getArgument(i));
-
-    normalizedLowerBounds.push_back(resultBounds.lowerBound);
-    normalizedUpperBounds.push_back(resultBounds.upperBound);
-    normalizedSteps.push_back(resultBounds.step);
+    OpBuilder::InsertionGuard g2(rewriter);
+    rewriter.setInsertionPoint(loops);
+    Value lb = loops.getLowerBound()[i];
+    Value ub = loops.getUpperBound()[i];
+    Value step = loops.getStep()[i];
+    auto newLoopParams = normalizeLoop(rewriter, loc, lb, ub, step);
+    normalizedLowerBounds.push_back(newLoopParams.lowerBound);
+    normalizedUpperBounds.push_back(newLoopParams.upperBound);
+    normalizedSteps.push_back(newLoopParams.step);
+
+    rewriter.setInsertionPointToStart(loops.getBody());
+    unNormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
+                                 step);
   }
 
   // Combine iteration spaces.
   SmallVector<Value, 3> lowerBounds, upperBounds, steps;
-  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
-  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   for (auto &sortedDimension : sortedDimensions) {
-    Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     for (auto idx : sortedDimension) {
-      newUpperBound = outsideBuilder.create<arith::MulIOp>(
+      newUpperBound = rewriter.create<arith::MulIOp>(
           loc, newUpperBound, normalizedUpperBounds[idx]);
     }
     lowerBounds.push_back(cst0);
@@ -651,7 +844,7 @@ void mlir::collapseParallelLoops(
   // value. The remainders then determine based on that range, which iteration
   // of the original induction value this represents. This is a normalized value
   // that is un-normalized already by the previous logic.
-  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+  auto newPloop = rewriter.create<scf::ParallelOp>(
       loc, lowerBounds, upperBounds, steps,
       [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
         for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 7c0930eedc8568..14df5b7b80dcc8 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1452,3 +1452,19 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
   %1 = affine.apply affine_map<()[s0, s1, s2] -> ((s0 - ((s0 - s2) mod s1) - s2) mod s1)> ()[%ub, %step, %lb]
   return %0, %1 : index, index
 }
+
+// -----
+
+func.func @outer_unit_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) -> (index, index, index, index, index) {
+  %c1 = arith.constant 1 : index
+  %0:5 = affine.delinearize_index %arg0 into (%c1, %c1, %arg1, %c1, %arg2) : index, index, index, index, index
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index
+}
+//     CHECK: func @outer_unit_delinearize
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[RESULT:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[C1]], %[[ARG2]])
+//      CHECK:   return %[[C0]], %[[C0]], %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 9c17fb24be690a..80f5f80f3720d0 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -1,38 +1,30 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing --cse %s | FileCheck %s
 
 // CHECK-LABEL: @one_3d_nest
 func.func @one_3d_nest() {
-  // Capture original bounds.  Note that for zero-based step-one loops, the
-  // upper bound is also the number of iterations.
-  // CHECK: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK: %[[orig_step:.*]] = arith.constant 1
-  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c42 = arith.constant 42 : index
   %c56 = arith.constant 56 : index
-  // The range of the new scf.
-  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
-  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
-
+  // CHECK-DAG: %[[c0:.+]] = arith.constant 0
+  // CHECK-DAG: %[[c1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[c42:.+]] = arith.constant 42
+  // CHECK-DAG: %[[c56:.+]] = arith.constant 56
+  // CHECK-DAG: %[[c3:.+]] = arith.constant 3
   // Updated loop bounds.
-  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]]
+  // CHECK: %[[range:.*]] = arith.constant 7056
+  // CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[range]] step %[[c1]]
   scf.for %i = %c0 to %c42 step %c1 {
     // Inner loops must have been removed.
     // CHECK-NOT: scf.for
 
     // Reconstruct original IVs from the linearized one.
-    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
-    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (%[[c42]], %[[c56]], %[[c3]])
     scf.for %j = %c0 to %c56 step %c1 {
       scf.for %k = %c0 to %c3 step %c1 {
-        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
         "use"(%i, %j, %k) : (index, index, index) -> ()
       }
     }
@@ -40,25 +32,26 @@ func.func @one_3d_nest() {
   return
 }
 
+// -----
+
 // Check that there is no chasing the replacement of value uses by ensuring
 // multiple uses of loop induction variables get rewritten to the same values.
 
-// CHECK-LABEL: @multi_use
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: @multi_use
 func.func @multi_use() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
+  // CHECK: %[[c9:.+]] = arith.constant 9
   // CHECK: scf.for %[[iv:.*]] =
   scf.for %i = %c1 to %c10 step %c1 {
     scf.for %j = %c1 to %c10 step %c1 {
       scf.for %k = %c1 to %c10 step %c1 {
-        // CHECK: %[[k_unshifted:.*]] = arith.remsi %[[iv]], %[[k_extent:.*]]
-        // CHECK: %[[ij:.*]] = arith.divsi %[[iv]], %[[k_extent]]
-        // CHECK: %[[j_unshifted:.*]] = arith.remsi %[[ij]], %[[j_extent:.*]]
-        // CHECK: %[[i_unshifted:.*]] = arith.divsi %[[ij]], %[[j_extent]]
-        // CHECK: %[[k:.*]] = arith.addi %[[k_unshifted]]
-        // CHECK: %[[j:.*]] = arith.addi %[[j_unshifted]]
-        // CHECK: %[[i:.*]] = arith.addi %[[i_unshifted]]
+        // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c9]], %[[c9]], %[[c9]])
+        // CHECK: %[[k:.*]] = affine.apply #[[MAP]]()[%[[delinearize]]#2]
+        // CHECK: %[[j:.*]] = affine.apply #[[MAP]]()[%[[delinearize]]#1]
+        // CHECK: %[[i:.*]] = affine.apply #[[MAP]]()[%[[delinearize]]#0]
 
         // CHECK: "use1"(%[[i]], %[[j]], %[[k]])
         "use1"(%i,%j,%k) : (index,index,index) -> ()
@@ -72,13 +65,11 @@ func.func @multi_use() {
   return
 }
 
+// -----
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 3 + 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 2 + 5)>
 func.func @unnormalized_loops() {
-  // CHECK: %[[orig_step_i:.*]] = arith.constant 2
-  // CHECK: %[[orig_step_j:.*]] = arith.constant 3
-  // CHECK: %[[orig_lb_i:.*]] = arith.constant 5
-  // CHECK: %[[orig_lb_j:.*]] = arith.constant 7
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 10
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 17
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c5 = arith.constant 5 : index
@@ -87,31 +78,22 @@ func.func @unnormalized_loops() {
   %c17 = arith.constant 17 : index
 
   // Number of iterations in the outer scf.
-  // CHECK: %[[diff_i:.*]] = arith.subi %[[orig_ub_i]], %[[orig_lb_i]]
-  // CHECK: %[[numiter_i:.*]] = arith.ceildivsi %[[diff_i]], %[[orig_step_i]]
+  // CHECK: %[[C3:.+]] = arith.constant 3
+  // CHECK: %[[C4:.+]] = arith.constant 4
 
   // Normalized lower bound and step for the outer scf.
-  // CHECK: %[[lb_i:.*]] = arith.constant 0
-  // CHECK: %[[step_i:.*]] = arith.constant 1
-
-  // Number of iterations in the inner loop, the pattern is the same as above,
-  // only capture the final result.
-  // CHECK: %[[numiter_j:.*]] = arith.ceildivsi {{.*}}, %[[orig_step_j]]
 
   // New bounds of the outer scf.
-  // CHECK: %[[range:.*]] = arith.muli %[[numiter_i]], %[[numiter_j]]
-  // CHECK: scf.for %[[i:.*]] = %[[lb_i]] to %[[range]] step %[[step_i]]
+  // CHECK: %[[range:.*]] = arith.constant 12
+  // CHECK: scf.for %[[i:.*]] = %{{.+}} to %[[range]] step %{{.+}}
   scf.for %i = %c5 to %c10 step %c2 {
     // The inner loop has been removed.
     // CHECK-NOT: scf.for
     scf.for %j = %c7 to %c17 step %c3 {
       // The IVs are rewritten.
-      // CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter_j]]
-      // CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter_j]]
-      // CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j]]
-      // CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb_j]]
-      // CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step_i]]
-      // CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb_i]]
+      // CHECK: %[[delinearize:.+]]:2 = affine.delinearize_index %[[i]] into (%[[C3]], %[[C4]])
+      // CHECK: %[[orig_j:.*]] = affine.apply #[[MAP]]()[%[[delinearize]]#1]
+      // CHECK: %[[orig_i:.*]] = affine.apply #[[MAP1]]()[%[[delinearize]]#0]
       // CHECK: "use"(%[[orig_i]], %[[orig_j]])
       "use"(%i, %j) : (index, index) -> ()
     }
@@ -119,8 +101,13 @@ func.func @unnormalized_loops() {
   return
 }
 
+// -----
+
 // Check with parametric loop bounds and steps, capture the bounds here.
-// CHECK-LABEL: @parametric
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK: @parametric
 // CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:
 // CHECK-SAME: %[[orig_ub1:[A-Za-z0-9]+]]:
 // CHECK-SAME: %[[orig_step1:[A-Za-z0-9]+]]:
@@ -131,25 +118,22 @@ func.func @parametric(%lb1 : index, %ub1 : index, %step1 : index,
                  %lb2 : index, %ub2 : index, %step2 : index) {
   // Compute the number of iterations for each of the loops and the total
   // number of iterations.
-  // CHECK: %[[range1:.*]] = arith.subi %[[orig_ub1]], %[[orig_lb1]]
-  // CHECK: %[[numiter1:.*]] = arith.ceildivsi %[[range1]], %[[orig_step1]]
-  // CHECK: %[[range2:.*]] = arith.subi %[[orig_ub2]], %[[orig_lb2]]
-  // CHECK: %[[numiter2:.*]] = arith.ceildivsi %[[range2]], %[[orig_step2]]
-  // CHECK: %[[range:.*]] = arith.muli %[[numiter1]], %[[numiter2]] : index
+  // CHECK-DAG: %[[numiter1:.+]] = affine.apply #[[MAP]]()[%[[orig_lb1]], %[[orig_ub1]], %[[orig_step1]]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[numiter2:.+]] = affine.apply #[[MAP]]()[%[[orig_lb2]], %[[orig_ub2]], %[[orig_step2]]]
+  // CHECK-DAG: %[[range:.+]] = affine.apply #[[MAP1]]()[%[[orig_lb1]], %[[orig_ub1]], %[[orig_step1]], %[[orig_lb2]], %[[orig_ub2]], %[[orig_step2]]]
 
   // Check that the outer loop is updated.
-  // CHECK: scf.for %[[i:.*]] = %c0{{.*}} to %[[range]] step %c1
+  // CHECK: scf.for %[[i:.*]] = %[[C0]] to %[[range]] step %[[C1]]
   scf.for %i = %lb1 to %ub1 step %step1 {
     // Check that the inner loop is removed.
     // CHECK-NOT: scf.for
     scf.for %j = %lb2 to %ub2 step %step2 {
       // Remapping of the induction variables.
-      // CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter2]] : index
-      // CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter2]] : index
-      // CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step2]]
-      // CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb2]]
-      // CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step1]]
-      // CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb1]]
+      // CHECK: %[[delinearize:.+]]:2 = affine.delinearize_index %[[i]] into (%[[numiter1]], %[[numiter2]])
+      // CHECK: %[[orig_j:.*]] = affine.apply #[[MAP2]]()[%[[delinearize]]#1, %[[orig_step2]], %[[orig_lb2]]]
+      // CHECK: %[[orig_i:.*]] = affine.apply #[[MAP2]]()[%[[delinearize]]#0, %[[orig_step1]], %[[orig_lb1]]]
 
       // CHECK: "foo"(%[[orig_i]], %[[orig_j]])
       "foo"(%i, %j) : (index, index) -> ()
@@ -158,19 +142,20 @@ func.func @parametric(%lb1 : index, %ub1 : index, %step1 : index,
   return
 }
 
+// -----
+
 // CHECK-LABEL: @two_bands
 func.func @two_bands() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
-  // CHECK: %[[outer_range:.*]] = arith.muli
+  // CHECK: %[[outer_range:.*]] = arith.constant 100
   // CHECK: scf.for %{{.*}} = %{{.*}} to %[[outer_range]]
   scf.for %i = %c0 to %c10 step %c1 {
     // Check that the "j" loop was removed and that the inner loops were
-    // coalesced as well.  The preparation step for coalescing will inject the
-    // subtraction operation unlike the IV remapping.
+    // coalesced as well.  The coalescing will inject the delinearization.
     // CHECK-NOT: scf.for
-    // CHECK: arith.subi
+    // CHECK: affine.delinearize_index
     scf.for %j = %c0 to %c10 step %c1 {
       // The inner pair of loops is coalesced separately.
       // CHECK: scf.for
@@ -239,19 +224,15 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
   }
   return
 }
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T2:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T3:.*]] = affine.apply #[[IDENTITY]]()[%[[T0]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[PRODUCT]](%[[T3]])[%[[T4]]]
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[IDENTITY]]()[%[[T2]]]
-// CHECK-DAG: %[[T7:.*]] = affine.apply #[[PRODUCT]](%[[T5]])[%[[T6]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T7]]
-// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T6]]]
-// CHECK-DAG:    %[[T9:.*]] = affine.apply #[[FLOOR]](%[[IV]])[%[[T6]]]
-// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T9]])[%[[T4]]]
-// CHECK-DAG:    %[[I:.*]] = affine.apply #[[FLOOR]](%[[T9]])[%[[T4]]]
+// CHECK: %[[DIM:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.*]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T0]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[PRODUCT]](%[[T1]])[%[[T0]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T2]]
+// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T0]]]
+// CHECK-DAG:    %[[T4:.*]] = affine.apply #[[FLOOR]](%[[IV]])[%[[T0]]]
+// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T4]])[%[[T0]]]
+// CHECK-DAG:    %[[I:.*]] = affine.apply #[[FLOOR]](%[[T4]])[%[[T0]]]
 // CHECK-NEXT:    "test.foo"(%[[I]], %[[J]], %[[K]])
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return
@@ -259,11 +240,13 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
 // -----
 
 // Check coalescing of affine.for loops when some of the loop has constant upper bounds while others have nin constant upper bounds.
-// CHECK-DAG: #[[IDENTITY:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[PRODUCT:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
-// CHECK-DAG: #[[SIXTY_FOUR:.*]] = affine_map<() -> (64)>
-// CHECK-DAG: #[[MOD:.*]] = affine_map<(d0)[s0] -> (d0 mod s0)>
-// CHECK-DAG: #[[DIV:.*]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[PRODUCT:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[SIXTY_FOUR:.+]] = affine_map<() -> (64)>
+// CHECK-DAG: #[[MOD:.+]] = affine_map<(d0)[s0] -> (d0 mod s0)>
+// CHECK-DAG: #[[DIV:.+]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
+// CHECK: func.func @coalesce_affine_for
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>
 func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   %M = memref.dim %arg0, %c0 : memref<?x?xf32>
@@ -277,18 +260,17 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
   }
   return
 }
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T2:.*]] = affine.apply #[[IDENTITY]]()[%[[T0]]]
-// CHECK-DAG: %[[T3:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[PRODUCT]](%[[T2]])[%[[T3]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[SIXTY_FOUR]]()
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[PRODUCT]](%[[T4]])[%[[T5]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T6]]
-// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T5]]]
-// CHECK-DAG:    %[[T8:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T5]]]
-// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T8]])[%[[T3]]]
-// CHECK-DAG:    %[[I:.*]] = affine.apply #[[DIV]](%[[T8]])[%[[T3]]]
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T0]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[SIXTY_FOUR]]()
+// CHECK-DAG: %[[T3:.*]] = affine.apply #[[PRODUCT]](%[[T1]])[%[[T2]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T3]]
+// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T2]]]
+// CHECK-DAG:    %[[T5:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T2]]]
+// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T5]])[%[[T0]]]
+// CHECK-DAG:    %[[I:.*]] = affine.apply #[[DIV]](%[[T5]])[%[[T0]]]
 // CHECK-NEXT:    "test.foo"(%[[I]], %[[J]], %[[K]])
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return
@@ -301,6 +283,8 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
 // CHECK-DAG: #[[PRODUCT:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
 // CHECK-DAG: #[[MOD:.*]] = affine_map<(d0)[s0] -> (d0 mod s0)>
 // CHECK-DAG: #[[DIV:.*]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
+// CHECK: func.func @coalesce_affine_for
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>
 #myMap = affine_map<()[s1] -> (s1, -s1)>
 func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
@@ -316,19 +300,17 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
  }
  return
 }
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T2:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T3:.*]] = affine.min #[[MAP0]]()[%[[T0]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[PRODUCT]](%[[T3]])[%[[T4]]]
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[IDENTITY]]()[%[[T2]]]
-// CHECK-DAG: %[[T7:.*]] = affine.apply #[[PRODUCT]](%[[T5]])[%[[T6]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T7]]
-// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T6]]]
-// CHECK-DAG:    %[[T9:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T6]]]
-// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T9]])[%[[T4]]]
-// CHECK-DAG:    %[[I:.*]] = affine.apply #[[DIV]](%[[T9]])[%[[T4]]]
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.*]] = affine.min #[[MAP0]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T1]]]
+// CHECK-DAG: %[[T3:.*]] = affine.apply #[[PRODUCT]](%[[T2]])[%[[T1]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T3]]
+// CHECK-DAG:    %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T1]]]
+// CHECK-DAG:    %[[T5:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T1]]]
+// CHECK-DAG:    %[[J:.*]] = affine.apply #[[MOD]](%[[T5]])[%[[T1]]]
+// CHECK-DAG:    %[[I:.*]] = affine.apply #[[DIV]](%[[T5]])[%[[T1]]]
 // CHECK-NEXT:    "test.foo"(%[[I]], %[[J]], %[[K]])
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return
@@ -342,12 +324,14 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
 func.func @test_loops_do_not_get_coalesced() {
   affine.for %i = 0 to 7 {
     affine.for %j = #map0(%i) to min #map1(%i) {
+      "use"(%i, %j) : (index, index) -> ()
     }
   }
   return
 }
 // CHECK: affine.for %[[IV0:.*]] = 0 to 7
 // CHECK-NEXT: affine.for %[[IV1:.*]] = #[[MAP0]](%[[IV0]]) to min #[[MAP1]](%[[IV0]])
+// CHECK-NEXT: "use"(%[[IV0]], %[[IV1]])
 // CHECK-NEXT: }
 // CHECK-NEXT: }
 // CHECK-NEXT: return
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
index 2d59331b72cf65..666958e46ad4e7 100644
--- a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
+++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect --cse | FileCheck %s
 
 func.func @coalesce_inner() {
   %c0 = arith.constant 0 : index
@@ -14,7 +14,7 @@ func.func @coalesce_inner() {
       scf.for %k = %i to %j step %c1 {
         // Inner loop must have been removed.
         scf.for %l = %i to %j step %c1 {
-          arith.addi %i, %j : index
+          "use"(%i, %j) : (index, index) -> ()
         }
       } {coalesce}
     }
@@ -33,13 +33,19 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-DAG: #[[MAP:.+]] = affine_map<() -> (64)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 mod s0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
 func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+  // CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()
+  // CHECK: %[[UB:.+]] = affine.apply #[[MAP1]](%[[T0]])[%[[T0]]]
   // CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] {
   // CHECK-NOT: affine.for %[[IV2:.+]]
   affine.for %arg4 = 0 to 64 {
     affine.for %arg5 = 0 to 64 {
-      // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}]
-      // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}]
+      // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%{{.+}}]
+      // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%{{.+}}]
       // CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
       %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
       %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
@@ -70,9 +76,8 @@ func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64x
   scf.for %arg4 = %c0 to %c64 step %c1 {
     // CHECK-NOT: scf.for
     scf.for %arg5 = %c0 to %c64 step %c1 {
-      // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]]
-      // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]]
-      // CHECK-NEXT: %{{.+}} = memref.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
+      // CHECK: %[[IV:.+]]:2 = affine.delinearize_index %[[IV1]]
+      // CHECK: %{{.+}} = memref.load %{{.+}}[%[[IV]]#0, %[[IV]]#1] : memref<64x64xf32, 1>
       %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
       %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
       %2 = arith.addf %0, %1 : f32
@@ -96,3 +101,150 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @tensor_loops(%arg0 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+    %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> tensor<?x?xf32> {
+  %0 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg1 = %arg0) -> tensor<?x?xf32> {
+    %1 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg2 = %arg1) -> tensor<?x?xf32> {
+      %2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg3 = %arg2) -> tensor<?x?xf32> {
+        %3 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
+        scf.yield %3 : tensor<?x?xf32>
+      }
+      scf.yield %2 : tensor<?x?xf32>
+    }
+    scf.yield %1 : tensor<?x?xf32>
+  } {coalesce}
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8] -> ((((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5)) * ((-s6 + s7) ceildiv s8))>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+//      CHECK: func.func @tensor_loops(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[LB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[UB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[STEP0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[LB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[UB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[STEP1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[LB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[UB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[STEP2:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[NEWUB0:.+]] = affine.apply #[[MAP]]()[%[[LB0]], %[[UB0]], %[[STEP0]]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1
+//  CHECK-DAG:   %[[NEWUB1:.+]] = affine.apply #[[MAP]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
+//  CHECK-DAG:   %[[NEWUB2:.+]] = affine.apply #[[MAP]]()[%[[LB2]], %[[UB2]], %[[STEP2]]]
+//  CHECK-DAG:   %[[NEWUB:.+]] = affine.apply #[[MAP1]]()[%[[LB0]], %[[UB0]], %[[STEP0]], %[[LB1]], %[[UB1]], %[[STEP1]], %[[LB2]], %[[UB2]], %[[STEP2]]]
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[NEWUB]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[ARG0]])
+//      CHECK:     %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[IV]] into (%[[NEWUB0]], %[[NEWUB1]], %[[NEWUB2]])
+//      CHECK:     %[[K:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#2, %[[STEP2]], %[[LB2]]]
+//      CHECK:     %[[J:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#1, %[[STEP1]], %[[LB1]]]
+//      CHECK:     %[[I:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#0, %[[STEP0]], %[[LB0]]]
+//      CHECK:     %[[USE:.+]] = "use"(%[[ITER_ARG]], %[[I]], %[[J]], %[[K]])
+//      CHECK:     scf.yield %[[USE]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// Coalesce only first two loops, but not the last since the iter_args don't line up
+func.func @tensor_loops(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+    %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+      %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg5, %arg7 = %arg4) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+        %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+        scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+      }
+      scf.yield %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0, %1#1 : tensor<?x?xf32>, tensor<?x?xf32>
+  } {coalesce}
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+//     CHECK: scf.for
+//     CHECK:   %{{.+}}:2 = affine.delinearize_index
+//     CHECK:   scf.for
+// CHECK-NOT:     scf.for
+//     CHECK:   transform.named_sequence
+
+// -----
+
+// Coalesce only first two loops, but not the last since the yields don't match up
+func.func @tensor_loops(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+    %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+      %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg4, %arg7 = %arg5) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+        %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+        scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+      }
+      scf.yield %2#1, %2#0 : tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0, %1#1 : tensor<?x?xf32>, tensor<?x?xf32>
+  } {coalesce}
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+//     CHECK: scf.for
+//     CHECK:   %{{.+}}:2 = affine.delinearize_index
+//     CHECK:   scf.for
+// CHECK-NOT:     scf.for
+//     CHECK:   transform.named_sequence
+
+// -----
+
+// Coalesce only last two loops, but not the first since the yields don't match up
+func.func @tensor_loops(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+    %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+      %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg4, %arg7 = %arg5) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+        %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+        scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+      }
+      scf.yield %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#1, %1#0 : tensor<?x?xf32>, tensor<?x?xf32>
+  } {coalesce}
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+//     CHECK: scf.for
+//     CHECK:   scf.for
+//     CHECK:     %{{.+}}:2 = affine.delinearize_index
+// CHECK-NOT:     scf.for
+//     CHECK:   transform.named_sequence
+
diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir
index 660d7edb2fbb37..f020d677c82b07 100644
--- a/mlir/test/Transforms/parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(test-scf-parallel-loop-collapsing{collapsed-indices-0=0,3 collapsed-indices-1=1,4 collapsed-indices-2=2}, canonicalize))' | FileCheck %s
 
-// CHECK-LABEL: func @parallel_many_dims() {
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 10 + 9)>
+// CHECK: func @parallel_many_dims() {
 func.func @parallel_many_dims() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -28,19 +29,16 @@ func.func @parallel_many_dims() {
   return
 }
 
-// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index
-// CHECK-DAG: [[C10:%.*]] = arith.constant 10 : index
-// CHECK-DAG: [[C9:%.*]] = arith.constant 9 : index
-// CHECK-DAG: [[C6:%.*]] = arith.constant 6 : index
-// CHECK-DAG: [[C4:%.*]] = arith.constant 4 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
-// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
-// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: scf.parallel ([[NEW_I0:%.*]]) = ([[C0]]) to ([[C4]]) step ([[C1]]) {
-// CHECK:   [[V0:%.*]] = arith.remsi [[NEW_I0]], [[C2]] : index
-// CHECK:   [[I0:%.*]] = arith.divsi [[NEW_I0]], [[C2]] : index
-// CHECK:   [[V2:%.*]] = arith.muli [[V0]], [[C10]] : index
-// CHECK:   [[I3:%.*]] = arith.addi [[V2]], [[C9]] : index
-// CHECK:   "magic.op"([[I0]], [[C3]], [[C6]], [[I3]], [[C12]]) : (index, index, index, index, index) -> index
+// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.parallel (%[[NEW_I0:.*]]) = (%[[C0]]) to (%[[C4]]) step (%[[C1]]) {
+// CHECK:   %[[V0:.*]] = arith.remsi %[[NEW_I0]], %[[C2]] : index
+// CHECK:   %[[I0:.*]] = arith.divsi %[[NEW_I0]], %[[C2]] : index
+// CHECK:   %[[I3:.*]] = affine.apply #[[MAP]]()[%[[V0]]]
+// CHECK:   "magic.op"(%[[I0]], %[[C3]], %[[C6]], %[[I3]], %[[C12]]) : (index, index, index, index, index) -> index
 // CHECK:   scf.reduce
diff --git a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
index 542786b5fa5e57..812442fd9e38fb 100644
--- a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
@@ -13,22 +13,19 @@ func.func @collapse_to_single() {
   return
 }
 
-// CHECK-LABEL: func @collapse_to_single() {
-// CHECK-DAG:         [[C18:%.*]] = arith.constant 18 : index
-// CHECK-DAG:         [[C6:%.*]] = arith.constant 6 : index
-// CHECK-DAG:         [[C3:%.*]] = arith.constant 3 : index
-// CHECK-DAG:         [[C7:%.*]] = arith.constant 7 : index
-// CHECK-DAG:         [[C4:%.*]] = arith.constant 4 : index
-// CHECK-DAG:         [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG:         [[C0:%.*]] = arith.constant 0 : index
-// CHECK:         scf.parallel ([[NEW_I:%.*]]) = ([[C0]]) to ([[C18]]) step ([[C1]]) {
-// CHECK:           [[I0_COUNT:%.*]] = arith.remsi [[NEW_I]], [[C6]] : index
-// CHECK:           [[I1_COUNT:%.*]] = arith.divsi [[NEW_I]], [[C6]] : index
-// CHECK:           [[V0:%.*]] = arith.muli [[I0_COUNT]], [[C4]] : index
-// CHECK:           [[I1:%.*]] = arith.addi [[V0]], [[C7]] : index
-// CHECK:           [[V1:%.*]] = arith.muli [[I1_COUNT]], [[C3]] : index
-// CHECK:           [[I0:%.*]] = arith.addi [[V1]], [[C3]] : index
-// CHECK:           "magic.op"([[I0]], [[I1]]) : (index, index) -> index
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 + 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 + 3)>
+// CHECK: func @collapse_to_single() {
+// CHECK-DAG:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:         %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:         %[[C18:.*]] = arith.constant 18 : index
+// CHECK:         scf.parallel (%[[NEW_I:.*]]) = (%[[C0]]) to (%[[C18]]) step (%[[C1]]) {
+// CHECK:           %[[I0_COUNT:.*]] = arith.remsi %[[NEW_I]], %[[C6]] : index
+// CHECK:           %[[I1_COUNT:.*]] = arith.divsi %[[NEW_I]], %[[C6]] : index
+// CHECK:           %[[I1:.*]] = affine.apply #[[MAP]]()[%[[I0_COUNT]]]
+// CHECK:           %[[I0:.*]] = affine.apply #[[MAP1]]()[%[[I1_COUNT]]]
+// CHECK:           "magic.op"(%[[I0]], %[[I1]]) : (index, index) -> index
 // CHECK:           scf.reduce
 // CHECK-NEXT:    }
 // CHECK-NEXT:    return



More information about the Mlir-commits mailing list