[Mlir-commits] [mlir] [mlir][SCF] Modernize `coalesceLoops` method to handle `scf.for` loops with iter_args (PR #87019)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 28 16:17:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-affine

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

As part of this extension, this change also does some general cleanup:

1) Make all the methods take a `RewriterBase` as an argument instead of
   creating their own builders, which tend to crash when used within
   pattern rewrites.
2) For induction variables of `index` type, use
   `makeComposedFoldedAffineApply` to constant propagate where
   possible. Non-index types can't use this path, so they continue
   to generate the arith instructions.
3) Split `coalescePerfectlyNestedLoops` into two separate methods, one
   for `scf.for` and the other for `affine.for`. The templatization
   didn't seem to be buying much there.
4) Also add a canonicalization to `affine.delinearize_index` to drop
   the delinearization when the outer dimensions are `1` and replace
   those with `0`.

Also general cleanup of tests.

---

Patch is 69.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87019.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1) 
- (modified) mlir/include/mlir/Dialect/Affine/LoopUtils.h (+2-47) 
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+6-1) 
- (modified) mlir/include/mlir/IR/PatternMatch.h (+9) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+36) 
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp (+6-2) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+46) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (+8-1) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+307-114) 
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+16) 
- (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+92-108) 
- (modified) mlir/test/Dialect/SCF/transform-op-coalesce.mlir (+159-7) 
- (modified) mlir/test/Transforms/parallel-loop-collapsing.mlir (+14-16) 
- (modified) mlir/test/Transforms/single-parallel-loop-collapsing.mlir (+13-16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index edcfcfd830c443..a0b14614934519 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1095,6 +1095,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   ];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc5..d143954b78fc12 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -299,53 +299,8 @@ LogicalResult
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
-/// Walk either an scf.for or an affine.for to find a band to coalesce.
-template <typename LoopOpTy>
-LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
-  LogicalResult result(failure());
-  SmallVector<LoopOpTy> loops;
-  getPerfectlyNestedLoops(loops, op);
-
-  // Look for a band of loops that can be coalesced, i.e. perfectly nested
-  // loops with bounds defined above some loop.
-  // 1. For each loop, find above which parent loop its operands are
-  // defined.
-  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
-  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
-    operandsDefinedAbove[i] = i;
-    for (unsigned j = 0; j < i; ++j) {
-      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
-        operandsDefinedAbove[i] = j;
-        break;
-      }
-    }
-  }
-
-  // 2. Identify bands of loops such that the operands of all of them are
-  // defined above the first loop in the band.  Traverse the nest bottom-up
-  // so that modifications don't invalidate the inner loops.
-  for (unsigned end = loops.size(); end > 0; --end) {
-    unsigned start = 0;
-    for (; start < end - 1; ++start) {
-      auto maxPos =
-          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
-                            std::next(operandsDefinedAbove.begin(), end));
-      if (maxPos > start)
-        continue;
-      assert(maxPos == start &&
-             "expected loop bounds to be known at the start of the band");
-      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
-      if (succeeded(coalesceLoops(band)))
-        result = success();
-      break;
-    }
-    // If a band was found and transformed, keep looking at the loops above
-    // the outermost transformed loop.
-    if (start != end - 1)
-      end = start + 1;
-  }
-  return result;
-}
+/// Walk an affine.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
 
 } // namespace affine
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 883d11bcc4df06..bc09cc7f7fa5e0 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -100,11 +100,16 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
 /// `loops` contains a list of perfectly nested loops with bounds and steps
 /// independent of any loop induction variable involved in the nest.
 LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
+LogicalResult coalesceLoops(RewriterBase &rewriter,
+                            MutableArrayRef<scf::ForOp>);
+
+/// Walk an affine.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
 
 /// Take the ParallelLoop and for each set of dimension indices, combine them
 /// into a single dimension. combinedDimensions must contain each index into
 /// loops exactly once.
-void collapseParallelLoops(scf::ParallelOp loops,
+void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
 /// Unrolls this for operation by the specified unroll factor. Returns failure
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 070e6ed702f86a..fabe4cc401cff5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/TypeName.h"
 #include <optional>
 
@@ -697,6 +698,14 @@ class RewriterBase : public OpBuilder {
       return user != exceptedUser;
     });
   }
+  void
+  replaceAllUsesExcept(Value from, Value to,
+                       const SmallPtrSetImpl<Operation *> &preservedUsers) {
+    return replaceUsesWithIf(from, to, [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return !preservedUsers.contains(user);
+    });
+  }
 
   /// Used to notify the listener that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c591e5056480ca..4837ced453fa43 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4532,6 +4532,42 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+namespace {
+// When outer dimension used for delinearization are ones, the corresponding
+// results can all be replaced by zeros.
+struct DropUnitOuterDelinearizeDims
+    : public OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp indexOp,
+                                PatternRewriter &rewriter) const override {
+    ValueRange basis = indexOp.getBasis();
+    if (basis.empty()) {
+      return failure();
+    }
+    std::optional<int64_t> basisValue =
+        getConstantIntValue(getAsOpFoldResult(basis.front()));
+    if (!basisValue || basisValue != 1) {
+      return failure();
+    }
+    SmallVector<Value> replacements;
+    Location loc = indexOp.getLoc();
+    replacements.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    auto newIndexOp = rewriter.create<AffineDelinearizeIndexOp>(
+        loc, indexOp.getLinearIndex(), basis.drop_front());
+    replacements.append(newIndexOp->result_begin(), newIndexOp->result_end());
+    rewriter.replaceOp(indexOp, replacements);
+    return success();
+  }
+};
+
+} // namespace
+
+void AffineDelinearizeIndexOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<DropUnitOuterDelinearizeDims>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 1dc69ab493d477..1f23055544d2a5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -35,13 +35,17 @@ namespace {
 struct LoopCoalescingPass
     : public affine::impl::LoopCoalescingBase<LoopCoalescingPass> {
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect>();
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     func.walk<WalkOrder::PreOrder>([](Operation *op) {
       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
-        (void)coalescePerfectlyNestedLoops(scfForOp);
+        (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
-        (void)coalescePerfectlyNestedLoops(affineForOp);
+        (void)coalescePerfectlyNestedAffineLoops(affineForOp);
     });
   }
 };
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index af59973d7a92c5..0b2885e6396aae 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2765,3 +2765,49 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
 
   return success();
 }
+
+LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
+  LogicalResult result(failure());
+  SmallVector<AffineForOp> loops;
+  getPerfectlyNestedLoops(loops, op);
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+  // 1. For each loop, find above which parent loop its operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+  }
+
+  // 2. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band.  Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+      assert(maxPos == start &&
+             "expected loop bounds to be known at the start of the band");
+      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+      if (succeeded(coalesceLoops(band)))
+        result = success();
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+  return result;
+}
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c0918414820803..7e4faf8b73afbb 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
                                       transform::TransformState &state) {
   LogicalResult result(failure());
   if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
-    result = coalescePerfectlyNestedLoops(scfForOp);
+    result = coalescePerfectlyNestedSCFForLoops(scfForOp);
   else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
-    result = coalescePerfectlyNestedLoops(affineForOp);
+    result = coalescePerfectlyNestedAffineLoops(affineForOp);
 
   results.push_back(op);
   if (failed(result)) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index a69df025bcba81..ada0c971cb86bf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -28,6 +29,11 @@ namespace {
 struct TestSCFParallelLoopCollapsing
     : public impl::TestSCFParallelLoopCollapsingBase<
           TestSCFParallelLoopCollapsing> {
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect>();
+  }
+
   void runOnOperation() override {
     Operation *module = getOperation();
 
@@ -88,6 +94,7 @@ struct TestSCFParallelLoopCollapsing
     // Only apply the transformation on parallel loops where the specified
     // transformation is valid, but do NOT early abort in the case of invalid
     // loops.
+    IRRewriter rewriter(&getContext());
     module->walk([&](scf::ParallelOp op) {
       if (flattenedCombinedLoops.size() != op.getNumLoops()) {
         op.emitOpError("has ")
@@ -97,7 +104,7 @@ struct TestSCFParallelLoopCollapsing
             << flattenedCombinedLoops.size() << " iter args.";
         return;
       }
-      collapseParallelLoops(op, combinedLoops);
+      collapseParallelLoops(rewriter, op, combinedLoops);
     });
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 914aeb4fa79fda..ac42a21a883fa3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -12,7 +12,9 @@
 
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -472,18 +474,43 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
-                                OpBuilder &insideLoopBuilder, Location loc,
-                                Value lowerBound, Value upperBound, Value step,
-                                Value inductionVar) {
+/// Transform a loop with a strictly positive step
+///   for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+///     %i = %ii * %s + %lb
+/// Insert the induction variable remapping in the body of `inner`, which is
+/// expected to be either `loop` or another loop perfectly nested under `loop`.
+/// Insert the definition of new bounds immediate before `outer`, which is
+/// expected to be either `loop` or its parent in the loop nest.
+static OpFoldResult normalizeLoop(RewriterBase &rewriter, Location loc,
+                                  OpFoldResult lb, OpFoldResult ub,
+                                  OpFoldResult step) {
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr normalizeExpr = (s1 - s0).ceilDiv(s2);
+
+  OpFoldResult newUb = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, normalizeExpr, {lb, ub, step});
+  return newUb;
+}
+static LoopParams normalizeLoop(RewriterBase &rewriter, Location loc, Value lb,
+                                Value ub, Value step) {
+  auto isIndexType = [](Value v) { return v.getType().isa<IndexType>(); };
+  if (isIndexType(lb) && isIndexType(ub) && isIndexType(step)) {
+    OpFoldResult newUb =
+        normalizeLoop(rewriter, loc, getAsOpFoldResult(lb),
+                      getAsOpFoldResult(ub), getAsOpFoldResult(step));
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    return {zero, getValueOrCreateConstantIndexOp(rewriter, loc, newUb), one};
+  }
+  // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
   bool isZeroBased = false;
-  if (auto ubCst = getConstantIntValue(lowerBound))
-    isZeroBased = ubCst.value() == 0;
+  if (auto lbCst = getConstantIntValue(lb))
+    isZeroBased = lbCst.value() == 0;
 
   bool isStepOne = false;
   if (auto stepCst = getConstantIntValue(step))
@@ -493,62 +520,130 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
   // assuming the step is strictly positive.  Update the bounds and the step
   // of the loop to go from 0 to the number of iterations, if necessary.
   if (isZeroBased && isStepOne)
-    return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
-            /*step=*/step};
+    return {lb, ub, step};
 
-  Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+  Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value newUpperBound =
-      boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
-  Value newLowerBound =
-      isZeroBased ? lowerBound
-                  : boundsBuilder.create<arith::ConstantOp>(
-                        loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
-  Value newStep =
-      isStepOne ? step
-                : boundsBuilder.create<arith::ConstantOp>(
-                      loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
-  // Insert code computing the value of the original loop induction variable
-  // from the "normalized" one.
-  Value scaled =
-      isStepOne
-          ? inductionVar
-          : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
-  Value shifted =
-      isZeroBased
-          ? scaled
-          : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
-  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
-                                       shifted.getDefiningOp()};
-  inductionVar.replaceAllUsesExcept(shifted, preserve);
-  return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
-          /*step=*/newStep};
+      isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+
+  Value newLowerBound = isZeroBased
+                            ? lb
+                            : rewriter.create<arith::ConstantOp>(
+                                  loc, rewriter.getZeroAttr(lb.getType()));
+  Value newStep = isStepOne
+                      ? step
+                      : rewriter.create<arith::ConstantOp>(
+                            loc, rewriter.getIntegerAttr(step.getType(), 1));
+
+  return {newLowerBound, newUpperBound, newStep};
 }
 
-/// Transform a loop with a strictly positive step
-///   for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-///     %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
-  OpBuilder builder(outer);
-  OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
-  auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
-                                  loop.getLowerBound(), loop.getUpperBound(),
-                                  loop.getStep(), loop.getInductionVar());
-
-  loop.setLowerBound(loopPieces.lowerBound);
-  loop.setUpperBound(loopPieces.upperBound);
-  loop.setStep(loopPieces.step);
+/// Get back the original induction variable values after loop normalization
+static void unNormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+                                         Value normalizedIv, Value origLb,
+                                         Value origStep) {
+  Value unNormalizedIv;
+  std::optional<Operation *> preserve;
+  if (normalizedIv.getType().isa<IndexType>()) {
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    AffineExpr ivExpr = (s0 * s1) + s2;
+    OpFoldResult newIv = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, ivExpr,
+        ArrayRef<OpFoldResult>{normalizedIv, origStep, origLb});
+    unNormalizedIv = getValueOrCreateConstantIndexOp(rewriter, loc, newIv);
+    preserve = unNormalizedIv.getDefiningOp();
+  } else {
+    bool isStepOne = isConstantIntValue(origStep, 1);
+    bool isZeroBased = isConstantIntValue(origLb, 0);
+
+    Value scaled = normalizedIv;
+    if (!isStepOne) {
+      scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+      preserve = scaled.getDefiningOp();
+    }
+    unNormalizedIv = scaled;
+    if (!isZeroBased)
+      unNormalizedIv = rewriter.create...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/87019


More information about the Mlir-commits mailing list