[Mlir-commits] [mlir] [mlir] Add loop bounds normalization pass (PR #93781)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 30 01:11:15 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jorn Tuyls (jtuyls)

<details>
<summary>Changes</summary>

Add a pass to normalize loop bounds, i.e. rewrite each loop so that its lower bounds are zero and its steps are one, computing the corresponding new upper bounds. Then, insert new `affine.apply` operations to calculate the denormalized index values and replace all uses of the original induction variables with the results of the `affine.apply` operations.

I created a new interface for loop-like operations with induction variables (`LoopLikeWithInductionVarsOpInterface`) instead of adding the new methods to `LoopLikeOpInterface`, because not all loop-like operations have induction variables and those operations should not be required to implement these methods.

cc @<!-- -->MaheshRavishankar @<!-- -->qedawkins 

---

Patch is 41.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93781.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+97) 
- (added) mlir/include/mlir/Dialect/Utils/LoopUtils.h (+30) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.h (+4) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+126) 
- (modified) mlir/include/mlir/Transforms/Passes.h (+4) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+6) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+60) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+1-54) 
- (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Utils/LoopUtils.cpp (+52) 
- (modified) mlir/lib/IR/Operation.cpp (-2) 
- (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (+25) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+2) 
- (added) mlir/lib/Transforms/NormalizeLoopBounds.cpp (+118) 
- (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+4-7) 
- (added) mlir/test/Transforms/normalize-loop-bounds.mlir (+266) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b063aa772bab..0e23257456223 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -140,6 +140,7 @@ def ForOp : SCF_Op<"for",
         "getSingleUpperBound", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
+       LoopLikeWithInductionVarsOpInterface,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -267,6 +268,74 @@ def ForOp : SCF_Op<"for",
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
 
+    /// Return the induction variables (the leading block arguments).
+    ::mlir::ValueRange getInductionVars() {
+      return getBody()->getArguments().take_front(getNumInductionVars());
+    }
+
+    /// Get the lower bound as a single-element list of `OpFoldResult`.
+    SmallVector<OpFoldResult> getMixedLowerBound() {
+      return {getAsOpFoldResult(getLowerBound())};
+    }
+
+    /// Get the upper bound as a single-element list of `OpFoldResult`.
+    SmallVector<OpFoldResult> getMixedUpperBound() {
+      return {getAsOpFoldResult(getUpperBound())};
+    }
+
+    /// Get the step as a single-element list of `OpFoldResult`.
+    SmallVector<OpFoldResult> getMixedStep() {
+      return {getAsOpFoldResult(getStep())};
+    }
+
+    /// Get the lower bound as a single-element list of values.
+    SmallVector<Value> getLowerBound(OpBuilder &b) {
+      return {getLowerBound()};
+    }
+
+    /// Get the upper bound as a single-element list of values.
+    SmallVector<Value> getUpperBound(OpBuilder &b) {
+      return {getUpperBound()};
+    }
+
+    /// Get the step as a single-element list of values.
+    SmallVector<Value> getStep(OpBuilder &b) {
+      return {getStep()};
+    }
+
+    /// Set the lower bound; attribute bounds become `index` constants.
+    void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs) {
+      setLowerBounds(getValueOrCreateConstantIndexOp(b, getLoc(), lbs));
+    }
+
+    /// Set the upper bound; attribute bounds become `index` constants.
+    void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs) {
+      setUpperBounds(getValueOrCreateConstantIndexOp(b, getLoc(), ubs));
+    }
+
+    /// Set the step; an attribute step becomes an `index` constant.
+    void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps) {
+      setSteps(getValueOrCreateConstantIndexOp(b, getLoc(), steps));
+    }
+
+    /// Set the lower bound from a single-element array of values.
+    void setLowerBounds(ArrayRef<Value> lbs) {
+      assert(lbs.size() == 1 && "expected a single lower bound");
+      setLowerBound(lbs[0]);
+    }
+
+    /// Set the upper bound from a single-element array of values.
+    void setUpperBounds(ArrayRef<Value> ubs) {
+      assert(ubs.size() == 1 && "expected a single upper bound");
+      setUpperBound(ubs[0]);
+    }
+
+    /// Set the step from a single-element array of values.
+    void setSteps(ArrayRef<Value> steps) {
+      assert(steps.size() == 1 && "expected a single step");
+      setStep(steps[0]);
+    }
+
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
@@ -304,6 +373,7 @@ def ForallOp : SCF_Op<"forall", [
           ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
            "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
+       LoopLikeWithInductionVarsOpInterface,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -543,6 +613,33 @@ def ForallOp : SCF_Op<"forall", [
       return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep());
     }
 
+    /// Set the lower bounds; static values are stored as attributes.
+    void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs);
+
+    /// Set the upper bounds; static values are stored as attributes.
+    void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs);
+
+    /// Set the steps; static values are stored as attributes.
+    void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps);
+
+    /// Set the lower bounds from values (only attributes are built, no ops).
+    void setLowerBounds(ArrayRef<Value> lbs) {
+      OpBuilder b(getOperation()->getContext());
+      setMixedLowerBounds(b, getAsOpFoldResult(lbs));
+    }
+
+    /// Set the upper bounds from values (only attributes are built, no ops).
+    void setUpperBounds(ArrayRef<Value> ubs) {
+      OpBuilder b(getOperation()->getContext());
+      setMixedUpperBounds(b, getAsOpFoldResult(ubs));
+    }
+
+    /// Set the steps from values (only attributes are built, no ops).
+    void setSteps(ArrayRef<Value> steps) {
+      OpBuilder b(getOperation()->getContext());
+      setMixedSteps(b, getAsOpFoldResult(steps));
+    }
+
     int64_t getRank() { return getStaticLowerBound().size(); }
 
     /// Number of operands controlling the loop: lbs, ubs, steps
diff --git a/mlir/include/mlir/Dialect/Utils/LoopUtils.h b/mlir/include/mlir/Dialect/Utils/LoopUtils.h
new file mode 100644
index 0000000000000..15e901dc0e45e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/LoopUtils.h
@@ -0,0 +1,30 @@
+//===- LoopUtils.h - Helpers related to loop operations ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for loop operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+// This structure is to pass and return sets of loop parameters without
+// confusing the order.
+struct LoopParams {
+  Value lowerBound;
+  Value upperBound;
+  Value step;
+};
+
+/// Calculate the normalized loop upper bounds with lower bound equal to zero
+/// and step equal to one.
+LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                    Value lb, Value ub, Value step);
+
+} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 42609e824c86a..fab5ffa26e574 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
 #define MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
@@ -28,6 +29,9 @@ using NewYieldValuesFn = std::function<SmallVector<Value>(
 namespace detail {
 /// Verify invariants of the LoopLikeOpInterface.
 LogicalResult verifyLoopLikeOpInterface(Operation *op);
+
+/// Verify invariants of the LoopLikeWithInductionVarsOpInterface.
+LogicalResult verifyLoopLikeWithInductionVarsOpInterface(Operation *op);
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f0dc6e60eba58..95a8c5a244b62 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -375,6 +375,132 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
   }];
 }
 
+def LoopLikeWithInductionVarsOpInterface
+    : OpInterface<"LoopLikeWithInductionVarsOpInterface"> {
+  let description = [{
+    Interface for loop-like operations with one or more induction variables.
+    This interface contains helper functions for retrieving and updating the
+    lower bound, upper bound and step size for each induction variable and
+    provides a utility function to check whether the loop is normalized, i.e.
+    all lower bounds are equal to zero and steps are equal to one.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Return the induction variables. There is one lower bound, upper bound
+        and step for each induction variable.
+      }],
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getInductionVars"
+    >,
+    InterfaceMethod<[{
+        Return the lower bound values or attributes as OpFoldResult.
+      }],
+      /*retTy=*/"SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedLowerBound"
+    >,
+    InterfaceMethod<[{
+        Return the step values or attributes as OpFoldResult.
+      }],
+      /*retTy=*/"SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedStep"
+    >,
+    InterfaceMethod<[{
+        Return the upper bound values or attributes as OpFoldResult.
+      }],
+      /*retTy=*/"SmallVector<::mlir::OpFoldResult>",
+      /*methodName=*/"getMixedUpperBound"
+    >,
+    InterfaceMethod<[{
+        Return the lower bounds as values.
+      }],
+      /*retTy=*/"SmallVector<Value>",
+      /*methodName=*/"getLowerBound",
+      /*args=*/(ins "OpBuilder &":$b)
+    >,
+    InterfaceMethod<[{
+        Return the steps as values.
+      }],
+      /*retTy=*/"SmallVector<Value>",
+      /*methodName=*/"getStep",
+      /*args=*/(ins "OpBuilder &":$b)
+    >,
+    InterfaceMethod<[{
+        Return the upper bounds as values.
+      }],
+      /*retTy=*/"SmallVector<Value>",
+      /*methodName=*/"getUpperBound",
+      /*args=*/(ins "OpBuilder &":$b)
+    >,
+    InterfaceMethod<[{
+        Set the lower bounds from an array of `OpFoldResult`.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setMixedLowerBounds",
+      /*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$lbs)
+    >,
+    InterfaceMethod<[{
+        Set the steps from an array of `OpFoldResult`.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setMixedSteps",
+      /*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$steps)
+    >,
+    InterfaceMethod<[{
+        Set the upper bounds from an array of `OpFoldResult`.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setMixedUpperBounds",
+      /*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$ubs)
+    >,
+    InterfaceMethod<[{
+        Set the lower bounds from an array of values.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setLowerBounds",
+      /*args=*/(ins "ArrayRef<Value>":$lbs)
+    >,
+    InterfaceMethod<[{
+        Set the steps from an array of values.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setSteps",
+      /*args=*/(ins "ArrayRef<Value>":$steps)
+    >,
+    InterfaceMethod<[{
+        Set the upper bounds from an array of values.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setUpperBounds",
+      /*args=*/(ins "ArrayRef<Value>":$ubs)
+    >,
+    InterfaceMethod<[{
+        Return true if all lower bounds are constant 0 and all steps constant 1.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isNormalized",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+          return llvm::all_of(results, [&](OpFoldResult ofr) {
+            auto intValue = getConstantIntValue(ofr);
+            return intValue.has_value() && intValue == val;
+          });
+        };
+        SmallVector<::mlir::OpFoldResult> lbs = $_op.getMixedLowerBound();
+        SmallVector<::mlir::OpFoldResult> steps = $_op.getMixedStep();
+        return allEqual(lbs, 0) && allEqual(steps, 1);
+      }]
+    >
+  ];
+
+  let verify = [{
+    return detail::verifyLoopLikeWithInductionVarsOpInterface($_op);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..755ec7ecdfbad 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -82,6 +82,10 @@ std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
 /// Creates a pass that hoists loop-invariant subset ops.
 std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
 
+/// Create a pass that normalizes the loop bounds of loop-like operations with
+/// induction variables.
+std::unique_ptr<Pass> createNormalizeLoopBoundsPass();
+
 /// Creates a pass to strip debug information from a function.
 std::unique_ptr<Pass> createStripDebugInfoPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..5d1256e502a12 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -377,6 +377,12 @@ def Mem2Reg : Pass<"mem2reg"> {
   ];
 }
 
+def NormalizeLoopBounds : Pass<"normalize-loop-bounds"> {
+  let summary = "Normalize the loop bounds of loop-like operations with "
+                "induction variables.";
+  let constructor = "mlir::createNormalizeLoopBoundsPass()";
+}
+
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 107fd0690f193..3e7becb094b6b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1387,6 +1387,48 @@ void ForallOp::build(
   build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
 }
 
+/// Replace one group of loop-control operands of `op` with `ofrs`. `segment`
+/// is the index of the group in `operandSegmentSizes` (0: lower bounds,
+/// 1: upper bounds, 2: steps). Dynamic entries of `ofrs` become operands,
+/// static entries are stored in the `staticAttrName` attribute, and the
+/// segment size of the group is updated to the new number of operands.
+static void setMixedLoopControl(ForallOp op, OpBuilder &b, unsigned segment,
+                                StringAttr staticAttrName,
+                                ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<int64_t> staticValues;
+  SmallVector<Value> dynamicValues;
+  dispatchIndexOpFoldResults(ofrs, dynamicValues, staticValues);
+  SmallVector<int32_t> segmentSizes = llvm::to_vector(
+      op->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes")
+          .asArrayRef());
+  // The loop-control groups are laid out back to back at the front of the
+  // operand list, so this group starts after all groups preceding it.
+  unsigned offset = 0;
+  for (unsigned i = 0; i < segment; ++i)
+    offset += segmentSizes[i];
+  op->setOperands(offset, segmentSizes[segment], dynamicValues);
+  op->setAttr(staticAttrName, b.getDenseI64ArrayAttr(staticValues));
+  segmentSizes[segment] = dynamicValues.size();
+  op->setAttr("operandSegmentSizes", b.getDenseI32ArrayAttr(segmentSizes));
+}
+
+/// Set the lower bounds from `OpFoldResult`.
+void ForallOp::setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs) {
+  setMixedLoopControl(*this, b, /*segment=*/0, getStaticLowerBoundAttrName(),
+                      lbs);
+}
+
+/// Set the upper bounds from `OpFoldResult`.
+void ForallOp::setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs) {
+  setMixedLoopControl(*this, b, /*segment=*/1, getStaticUpperBoundAttrName(),
+                      ubs);
+}
+
+/// Set the steps from `OpFoldResult`.
+void ForallOp::setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps) {
+  setMixedLoopControl(*this, b, /*segment=*/2, getStaticStepAttrName(), steps);
+}
+
 // Checks if the lbs are zeros and steps are ones.
 bool ForallOp::isNormalized() {
   auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..41f52cb84f4ed 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/LoopUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
@@ -29,16 +30,6 @@
 
 using namespace mlir;
 
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
-  Value lowerBound;
-  Value upperBound;
-  Value step;
-};
-} // namespace
-
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
     ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,50 +464,6 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-/// Transform a loop with a strictly positive step
-///   for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-///     %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                           Value lb, Value ub, Value step) {
-  // For non-index types, generate `arith` instructions
-  // Check if the loop is already known to have a constant zero lower bound or
-  // a constant one step.
-  bool isZeroBased = false;
-  if (auto lbCst = getConstantIntValue(lb))
-    isZeroBased = lbCst.value() == 0;
-
-  bool isStepOne = false;
-  if (auto stepCst = getConstantIntValue(step))
-    isStepOne = stepCst.value() == 1;
-
-  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
-  // assuming the step is strictly positive.  Update the bounds and the step
-  // of the loop to go from 0 to the number of iterations, if necessary.
-  if (isZeroBased && isStepOne)
-    return {lb, ub, step};
-
-  Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
-  Value newUpperBound =
-      isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
-
-  Value newLowerBound = isZeroBased
-                            ? lb
-                            : rewriter.create<arith::ConstantOp>(
-                                  loc, rewriter.getZeroAttr(lb.getType()));
-  Value newStep = isStepOne
-                      ? step
-                      : rewriter.create<arith::ConstantOp>(
-                            loc, rewriter.getIntegerAttr(step.getType(), 1));
-
-  return {newLowerBound, newUpperBound, newStep};
-}
-
 /// Get back the original induction variable values after loop normalization
 static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                          Value normalizedIv, Value origLb,
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index a0096e5f299d5..41b2fe287beb3 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRDialectUtils
   IndexingUtils.cpp
+  LoopUtils.cpp
   ReshapeOpsUtils.cpp
   StructuredOpsUtils.cpp
   StaticValueUtils.cpp
diff --git a/mlir/lib/Dialect/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Utils/LoopUtils.cpp
new file mode 100644
index 0000000000000..3d8aa5ef7dfc1
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/LoopUtils.cpp
@@ -0,0 +1,52 @@
+//...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/93781


More information about the Mlir-commits mailing list