[Mlir-commits] [mlir] [mlir] IntRangeNarrowing: Narrow loop induction variables. (PR #175455)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 11 12:25:56 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

There are two parts:
* Extend `LoopLikeOpInterface` so it can check whether a type is supported for an induction variable and can update the loop bounds and steps.
* Implement a `NarrowLoopBounds` pattern that uses this new interface to try to narrow loop induction variables, bounds, and steps.

---

Patch is 21.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175455.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+6) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+2-1) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+48) 
- (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+197-2) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+34) 
- (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (+11) 
- (modified) mlir/test/Dialect/Arith/int-range-narrowing.mlir (+51-11) 
- (modified) mlir/test/Dialect/SCF/invalid.mlir (+13) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index b03cf2db78041..18ac0dbc8d13e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -87,6 +87,12 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        ArrayRef<unsigned> bitwidthsSupported);
 
+/// Add patterns narrowing control flow values (loop induction variables,
+/// bounds, and steps) to one of `bitwidthsSupported` using int range analysis.
+void populateControlFlowValuesNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8bdf3e0b566ef..7c26d78683de9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -154,7 +154,8 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
-        "yieldTiledValuesAndReplace"]>,
+        "yieldTiledValuesAndReplace", "isValidInductionVarType",
+        "setLoopLowerBounds", "setLoopUpperBounds", "setLoopSteps"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index e09b8672f2d08..89526e92a4c92 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -243,6 +243,54 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::std::nullopt;
       }]
+    >,
+    InterfaceMethod<[{
+        Check if the given type is supported as a loop induction variable.
+        By default, only IndexType is supported.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isValidInductionVarType",
+      /*args=*/(ins "::mlir::Type":$type),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return type.isIndex();
+      }]
+    >,
+    InterfaceMethod<[{
+        Update the loop lower bounds. Returns failure if the loop doesn't
+        support updating bounds or if the number of bounds doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopLowerBounds",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
+    InterfaceMethod<[{
+        Update the loop upper bounds. Returns failure if the loop doesn't
+        support updating bounds or if the number of bounds doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopUpperBounds",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
+    InterfaceMethod<[{
+        Update the loop steps. Returns failure if the loop doesn't support
+        updating steps or if the number of steps doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopSteps",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$steps),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
     >
   ];
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 4780dbb4338bb..637e16a3963d6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRFuncTransforms
   MLIRInferIntRangeInterface
   MLIRIR
+  MLIRLoopLikeInterface
   MLIRMemRefDialect
   MLIRShardDialect
   MLIRPass
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index a85d3bb43a372..01b8ed7ccae54 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -50,8 +51,6 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
 
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
-  assert(oldVal.getType() == newVal.getType() &&
-         "Can't copy integer ranges between different types");
   auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
   if (!oldState)
     return;
@@ -479,6 +478,193 @@ struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
+struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
+  NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
+                   ArrayRef<unsigned> target)
+      : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
+        targetBitwidths(target) {}
+
+  LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
+                                PatternRewriter &rewriter) const override {
+    // Skip ops tagged by an earlier failed bounds update (see below).
+    if (loopLike->hasAttr("arith.bounds_narrowing_failed"))
+      return failure();
+
+    auto inductionVars = loopLike.getLoopInductionVars();
+    if (!inductionVars.has_value() || inductionVars->empty())
+      return rewriter.notifyMatchFailure(loopLike, "no induction variables");
+
+    auto lowerBounds = loopLike.getLoopLowerBounds();
+    auto upperBounds = loopLike.getLoopUpperBounds();
+    auto steps = loopLike.getLoopSteps();
+
+    if (!lowerBounds.has_value() || !upperBounds.has_value() ||
+        !steps.has_value())
+      return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
+
+    if (lowerBounds->size() != inductionVars->size() ||
+        upperBounds->size() != inductionVars->size() ||
+        steps->size() != inductionVars->size())
+      return rewriter.notifyMatchFailure(loopLike,
+                                         "mismatched bounds/steps count");
+
+    Location loc = loopLike->getLoc();
+    SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
+    SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
+    SmallVector<OpFoldResult> newSteps(*steps);
+    SmallVector<std::pair<size_t, std::pair<Type, CastKind>>> narrowings;
+
+    // Check each (indVar, lb, ub, step) tuple.
+    for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
+         llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
+
+      // Only process value operands, skip attributes.
+      auto maybeLb = dyn_cast<Value>(lbOFR);
+      auto maybeUb = dyn_cast<Value>(ubOFR);
+      auto maybeStep = dyn_cast<Value>(stepOFR);
+
+      if (!maybeLb || !maybeUb || !maybeStep)
+        continue;
+
+      // Collect ranges for (lb, ub, step, indVar); skip the tuple on failure.
+      SmallVector<ConstantIntRanges> ranges;
+      if (failed(collectRanges(
+              solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
+        continue;
+
+      const ConstantIntRanges &lbRange = ranges[0];
+      const ConstantIntRanges &ubRange = ranges[1];
+      const ConstantIntRanges &stepRange = ranges[2];
+      const ConstantIntRanges &indVarRange = ranges[3];
+
+      Type srcType = maybeLb.getType();
+
+      // Try target bitwidths in order; the first one that works is used.
+      for (unsigned targetBitwidth : targetBitwidths) {
+        Type targetType = getTargetType(srcType, targetBitwidth);
+        if (targetType == srcType)
+          continue;
+
+        // Check if the target type is valid for this loop's induction
+        // variables.
+        if (!loopLike.isValidInductionVarType(targetType))
+          continue;
+
+        // Check that all four ranges can be truncated; merge their cast kinds.
+        CastKind castKind = CastKind::Both;
+        for (const ConstantIntRanges &range : ranges) {
+          castKind = mergeCastKinds(castKind,
+                                    checkTruncatability(range, targetBitwidth));
+          if (castKind == CastKind::None)
+            break;
+        }
+
+        if (castKind == CastKind::None)
+          continue;
+
+        // Additional check: ensure indVar + step doesn't overflow.
+        // During loop increment, we compute iv_next = iv_current + step in the
+        // narrowed type. If this overflows, the loop behavior becomes incorrect.
+        // We must check both signed and unsigned ranges since we don't know
+        // whether the loop semantics treat values as signed or unsigned.
+        unsigned srcWidth = lbRange.smin().getBitWidth();
+        unsigned removedWidth = srcWidth - targetBitwidth;
+
+        // Check that every possible indVar + step fits in the target type.
+        APInt indVarPlusStepSmin = indVarRange.smin().sadd_sat(stepRange.smin());
+        APInt indVarPlusStepSmax = indVarRange.smax().sadd_sat(stepRange.smax());
+        APInt indVarPlusStepUmin = indVarRange.umin().uadd_sat(stepRange.umin());
+        APInt indVarPlusStepUmax = indVarRange.umax().uadd_sat(stepRange.umax());
+
+        bool indVarPlusStepFitsSigned =
+            indVarPlusStepSmin.getNumSignBits() >= (removedWidth + 1) &&
+            indVarPlusStepSmax.getNumSignBits() >= (removedWidth + 1);
+        bool indVarPlusStepFitsUnsigned =
+            indVarPlusStepUmin.countLeadingZeros() >= removedWidth &&
+            indVarPlusStepUmax.countLeadingZeros() >= removedWidth;
+
+        // Both signed and unsigned must fit since loop semantics are unknown.
+        if (!indVarPlusStepFitsSigned || !indVarPlusStepFitsUnsigned)
+          continue;
+
+        // Narrow the bounds and step; the loop op itself is updated later.
+        Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
+        Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
+        Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
+
+        newLowerBounds[idx] = newLb;
+        newUpperBounds[idx] = newUb;
+        newSteps[idx] = newStep;
+        narrowings.push_back({idx, {targetType, castKind}});
+        break;
+      }
+    }
+
+    if (narrowings.empty())
+      return failure();
+
+    // Record original induction var types; setType below overwrites them.
+    SmallVector<Type> origTypes;
+    for (auto [idx, typeAndCast] : narrowings) {
+      Value indVar = (*inductionVars)[idx];
+      origTypes.push_back(indVar.getType());
+    }
+
+    // Update bounds/steps, then IV types. On setter failure, tag the op so we
+    // don't retry. NOTE(review): earlier setters may already have applied.
+    bool updateFailed = false;
+    rewriter.modifyOpInPlace(loopLike, [&]() {
+      // Update the loop bounds and steps.
+      if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
+          failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
+          failed(loopLike.setLoopSteps(newSteps))) {
+        // Mark op to prevent future attempts. IR was modified (attribute
+        // added), so we must return success() from the pattern.
+        loopLike->setAttr("arith.bounds_narrowing_failed",
+                          rewriter.getUnitAttr());
+        updateFailed = true;
+        return;
+      }
+
+      // Update induction variable types.
+      for (auto [idx, typeAndCast] : narrowings) {
+        auto [targetType, castKind] = typeAndCast;
+        Value indVar = (*inductionVars)[idx];
+        auto blockArg = cast<BlockArgument>(indVar);
+
+        // Change the block argument type.
+        blockArg.setType(targetType);
+      }
+    });
+
+    if (updateFailed)
+      return success();
+
+    // Cast each narrowed IV back to its original type at the top of its block.
+    for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
+      auto [idx, typeAndCast] = narrowingInfo;
+      auto [targetType, castKind] = typeAndCast;
+      Value indVar = (*inductionVars)[idx];
+      auto blockArg = cast<BlockArgument>(indVar);
+      Type origType = origTypes[narrowingIdx];
+
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(blockArg.getOwner());
+      Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
+      copyIntegerRange(solver, blockArg, casted);
+
+      // Replace all other uses of the narrowed indVar with the casted value.
+      rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
+    }
+
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  SmallVector<unsigned, 4> targetBitwidths;
+};
+
 struct IntRangeOptimizationsPass final
     : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -520,6 +706,8 @@ struct IntRangeNarrowingPass final
 
     RewritePatternSet patterns(ctx);
     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
+    populateControlFlowValuesNarrowingPatterns(patterns, solver,
+                                               bitwidthsSupported);
 
     // We specifically need bottom-up traversal as cmpi pattern needs range
     // data, attached to its original argument values.
@@ -548,6 +736,13 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
                                                        bitwidthsSupported);
 }
 
+void mlir::arith::populateControlFlowValuesNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported) { // Loop IVs/bounds/steps only.
+  patterns.add<NarrowLoopBounds>(patterns.getContext(), solver,
+                                 bitwidthsSupported);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8803a6d136f7a..4d85e2f4bcf47 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -520,6 +520,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
   return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
 }
 
+bool ForOp::isValidInductionVarType(Type type) {
+  return type.isIndex() || type.isSignlessInteger(); // iN enables IV narrowing.
+}
+
+LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1) // scf.for has exactly one induction variable.
+    return failure();
+  if (auto val = dyn_cast<Value>(bounds[0])) {
+    setLowerBound(val); // No type check: lb/ub/step types must stay equal.
+    return success();
+  }
+  return failure(); // Attribute bounds are rejected, not materialized.
+}
+
+LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1) // scf.for has exactly one induction variable.
+    return failure();
+  if (auto val = dyn_cast<Value>(bounds[0])) {
+    setUpperBound(val); // No type check: lb/ub/step types must stay equal.
+    return success();
+  }
+  return failure(); // Attribute bounds are rejected, not materialized.
+}
+
+LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
+  if (steps.size() != 1) // scf.for has exactly one induction variable.
+    return failure();
+  if (auto val = dyn_cast<Value>(steps[0])) {
+    setStep(val); // No type check: lb/ub/step types must stay equal.
+    return success();
+  }
+  return failure(); // Attribute steps are rejected, not materialized.
+}
+
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
 /// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index d4cef29008c2a..ae2ac13187dff 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -110,5 +110,16 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
     ++i;
   }
 
+  // Verify that all induction variables have valid types.
+  auto inductionVars = loopLikeOp.getLoopInductionVars();
+  if (inductionVars.has_value()) {
+    for (auto [index, inductionVar] : llvm::enumerate(*inductionVars)) {
+      if (!loopLikeOp.isValidInductionVarType(inductionVar.getType()))
+        return op->emitOpError(std::to_string(index))
+               << "-th induction variable has invalid type: "
+               << inductionVar.getType();
+    }
+  }
+
   return success();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 42dd4294b86e9..9107bf649b561 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -314,21 +314,28 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // Motivating example for negative number support, added as a test case
 // and simplified
 // CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK-DAG: %[[C64_I8:.+]] = arith.constant 64 : i8
+// CHECK-DAG: %[[C64_I16:.+]] = arith.constant 64 : i16
+// CHECK-DAG: %[[C16_I16:.+]] = arith.constant 16 : i16
+// CHECK-DAG: %[[C0_I16:.+]] = arith.constant 0 : i16
 // CHECK: %[[TID:.+]] = test.with_bounds
 // CHECK-SAME: umax = 63
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
-// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
-// CHECK-DAG:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK-DAG:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
-//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
-//     CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
-//     CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-//     CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
-//     CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
-//     CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
-//     CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
-//     CHECK:   scf.if %[[V3]]
+// Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]]  : i16 {
+// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
+// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+// CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
+// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+// CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
+// CHECK:   scf.if %[[V3]]
 func.func @clamp_to_loop_bound_and_id() {
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index
@@ -349,6 +356,7 @@ func.func @clamp_to_loop_bound_and_id() {
 }
 
 func.func private @use(index)
+func.func private @use_i64(i64)
 
 // CHECK-LABEL: func.func @loop_with_iter_arg
 func.func @loop_with_iter_arg() {
@@ -377,3 +385,35 @@ func.func @loop_with_iter_arg() {
   }
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Loop bounds narrowing
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @narrow_loop_bounds
+func.func @narrow_loop_bounds() {
+  %c0_i64 = arith.constant 0 : i64
+  %c10_i64 = arith.constant 10 : i64
+  %c1_i64 = arith.constant 1 : i64
+
+  // CHECK-DAG: %[[C1_I8:.*]] = arith.constant 1 : i8
+  // CHECK-DAG: %[[LB:.*]] = test.with_bounds {smax = 0 : i64, smin = 0 : i64, umax = 0 : i64, umin = 0 : i64} : i64
+  // CHECK-DAG: %[[UB:.*]] = test.with_bounds {smax = 10 : i64, smin = 10 : i64, umax = 10 :...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/175455


More information about the Mlir-commits mailing list