[Mlir-commits] [mlir] [mlir] IntRangeNarrowing: Narrow loop induction variables. (PR #175455)

Ivan Butygin llvmlistbot at llvm.org
Mon Jan 12 10:54:30 PST 2026


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/175455

>From 57e264c0a51d3633fc4df4c7a9c0fae53acf19b3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 17:26:48 +0100
Subject: [PATCH 01/13] interface

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td        |  2 +-
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 12 ++++++++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp                   |  4 ++++
 mlir/lib/Interfaces/LoopLikeInterface.cpp         | 11 +++++++++++
 mlir/test/Dialect/SCF/invalid.mlir                | 13 +++++++++++++
 5 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8bdf3e0b566ef..b519d00e9613b 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -154,7 +154,7 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
-        "yieldTiledValuesAndReplace"]>,
+        "yieldTiledValuesAndReplace", "isValidInductionVarType"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index e09b8672f2d08..b9f76e7edc23c 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -243,6 +243,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::std::nullopt;
       }]
+    >,
+    InterfaceMethod<[{
+        Check if the given type is supported as a loop induction variable.
+        By default, only IndexType is supported.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isValidInductionVarType",
+      /*args=*/(ins "::mlir::Type":$type),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return type.isIndex();
+      }]
     >
   ];
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8803a6d136f7a..81982b5b8ef55 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -520,6 +520,10 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
   return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
 }
 
+bool ForOp::isValidInductionVarType(Type type) {
+  return type.isIndex() || type.isSignlessInteger();
+}
+
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
 /// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index d4cef29008c2a..ae2ac13187dff 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -110,5 +110,16 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
     ++i;
   }
 
+  // Verify that all induction variables have valid types.
+  auto inductionVars = loopLikeOp.getLoopInductionVars();
+  if (inductionVars.has_value()) {
+    for (auto [index, inductionVar] : llvm::enumerate(*inductionVars)) {
+      if (!loopLikeOp.isValidInductionVarType(inductionVar.getType()))
+        return op->emitOpError(std::to_string(index))
+               << "-th induction variable has invalid type: "
+               << inductionVar.getType();
+    }
+  }
+
   return success();
 }
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 13a9b1cd38d88..985e347fbf5ee 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -60,6 +60,19 @@ func.func @loop_for_single_block(%arg0: index) {
 
 func.func @loop_for_single_index_argument(%arg0: index) {
   // expected-error@+1 {{expected induction variable to be same type as bounds}}
+  "scf.for"(%arg0, %arg0, %arg0) (
+    {
+    ^bb0(%i0 : i32):
+      scf.yield
+    }
+  ) : (index, index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @loop_for_invalid_ind_type(%arg0: index) {
+  // expected-error@+1 {{0-th induction variable has invalid type: 'f32'}}
   "scf.for"(%arg0, %arg0, %arg0) (
     {
     ^bb0(%i0 : f32):

>From 8dc86a10cbccb259d5795797a9eb49acc556697e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 20:06:26 +0100
Subject: [PATCH 02/13] wip

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   6 +
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |   3 +-
 .../mlir/Interfaces/LoopLikeInterface.td      |  36 +++++
 .../Transforms/IntRangeOptimizations.cpp      | 153 +++++++++++++++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  30 ++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  39 ++++-
 6 files changed, 261 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index b03cf2db78041..18ac0dbc8d13e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -87,6 +87,12 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        ArrayRef<unsigned> bitwidthsSupported);
 
+/// Add patterns for narrowing control flow values (loop bounds, steps, etc.)
+/// based on int range analysis.
+void populateControlFlowValuesNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b519d00e9613b..7c26d78683de9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -154,7 +154,8 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
-        "yieldTiledValuesAndReplace", "isValidInductionVarType"]>,
+        "yieldTiledValuesAndReplace", "isValidInductionVarType",
+        "setLoopLowerBounds", "setLoopUpperBounds", "setLoopSteps"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index b9f76e7edc23c..89526e92a4c92 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -255,6 +255,42 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return type.isIndex();
       }]
+    >,
+    InterfaceMethod<[{
+        Update the loop lower bounds. Returns failure if the loop doesn't
+        support updating bounds or if the number of bounds doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopLowerBounds",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
+    InterfaceMethod<[{
+        Update the loop upper bounds. Returns failure if the loop doesn't
+        support updating bounds or if the number of bounds doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopUpperBounds",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
+    InterfaceMethod<[{
+        Update the loop steps. Returns failure if the loop doesn't support
+        updating steps or if the number of steps doesn't match.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"setLoopSteps",
+      /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$steps),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
     >
   ];
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index a85d3bb43a372..39038cc82553b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -50,8 +51,6 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
 
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
-  assert(oldVal.getType() == newVal.getType() &&
-         "Can't copy integer ranges between different types");
   auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
   if (!oldState)
     return;
@@ -479,6 +478,147 @@ struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
+struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
+  NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
+                   ArrayRef<unsigned> target)
+      : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
+        targetBitwidths(target) {}
+
+  LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
+                                PatternRewriter &rewriter) const override {
+    auto inductionVars = loopLike.getLoopInductionVars();
+    if (!inductionVars.has_value() || inductionVars->empty())
+      return rewriter.notifyMatchFailure(loopLike, "no induction variables");
+
+    auto lowerBounds = loopLike.getLoopLowerBounds();
+    auto upperBounds = loopLike.getLoopUpperBounds();
+    auto steps = loopLike.getLoopSteps();
+
+    if (!lowerBounds.has_value() || !upperBounds.has_value() ||
+        !steps.has_value())
+      return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
+
+    if (lowerBounds->size() != inductionVars->size() ||
+        upperBounds->size() != inductionVars->size() ||
+        steps->size() != inductionVars->size())
+      return rewriter.notifyMatchFailure(loopLike,
+                                         "mismatched bounds/steps count");
+
+    Location loc = loopLike->getLoc();
+    SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
+    SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
+    SmallVector<OpFoldResult> newSteps(*steps);
+    SmallVector<std::pair<size_t, std::pair<Type, CastKind>>> narrowings;
+
+    // Check each (indVar, lb, ub, step) tuple.
+    for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
+         llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
+
+      // Only process value operands, skip attributes.
+      auto maybeLb = dyn_cast<Value>(lbOFR);
+      auto maybeUb = dyn_cast<Value>(ubOFR);
+      auto maybeStep = dyn_cast<Value>(stepOFR);
+
+      if (!maybeLb || !maybeUb || !maybeStep)
+        continue;
+
+      // Collect ranges for (lb, ub, step, indVar).
+      SmallVector<ConstantIntRanges> ranges;
+      if (failed(collectRanges(
+              solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
+        continue;
+
+      Type srcType = maybeLb.getType();
+
+      // Try each target bitwidth.
+      for (unsigned targetBitwidth : targetBitwidths) {
+        Type targetType = getTargetType(srcType, targetBitwidth);
+        if (targetType == srcType)
+          continue;
+
+        // Check if the target type is valid for this loop's induction
+        // variables.
+        if (!loopLike.isValidInductionVarType(targetType))
+          continue;
+
+        // Check if all values in this tuple can be truncated.
+        CastKind castKind = CastKind::Both;
+        for (const ConstantIntRanges &range : ranges) {
+          castKind = mergeCastKinds(castKind,
+                                    checkTruncatability(range, targetBitwidth));
+          if (castKind == CastKind::None)
+            break;
+        }
+
+        if (castKind == CastKind::None)
+          continue;
+
+        // Narrow the bounds and step values.
+        Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
+        Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
+        Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
+
+        newLowerBounds[idx] = newLb;
+        newUpperBounds[idx] = newUb;
+        newSteps[idx] = newStep;
+        narrowings.push_back({idx, {targetType, castKind}});
+        break;
+      }
+    }
+
+    if (narrowings.empty())
+      return failure();
+
+    // Save original types before modifying.
+    SmallVector<Type> origTypes;
+    for (auto [idx, typeAndCast] : narrowings) {
+      Value indVar = (*inductionVars)[idx];
+      origTypes.push_back(indVar.getType());
+    }
+
+    rewriter.modifyOpInPlace(loopLike, [&]() {
+      // Update the loop bounds and steps.
+      if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
+          failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
+          failed(loopLike.setLoopSteps(newSteps)))
+        llvm_unreachable("Failed to update loop bounds/steps");
+
+      // Update induction variable types.
+      for (auto [idx, typeAndCast] : narrowings) {
+        auto [targetType, castKind] = typeAndCast;
+        Value indVar = (*inductionVars)[idx];
+        auto blockArg = cast<BlockArgument>(indVar);
+
+        // Change the block argument type.
+        blockArg.setType(targetType);
+      }
+    });
+
+    // Insert casts back to original type for uses.
+    for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
+      auto [idx, typeAndCast] = narrowingInfo;
+      auto [targetType, castKind] = typeAndCast;
+      Value indVar = (*inductionVars)[idx];
+      auto blockArg = cast<BlockArgument>(indVar);
+      Type origType = origTypes[narrowingIdx];
+
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(blockArg.getOwner());
+      Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
+      copyIntegerRange(solver, blockArg, casted);
+
+      // Replace all uses of the narrowed indVar with the casted value.
+      rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
+    }
+
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  SmallVector<unsigned, 4> targetBitwidths;
+};
+
 struct IntRangeOptimizationsPass final
     : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -520,6 +660,8 @@ struct IntRangeNarrowingPass final
 
     RewritePatternSet patterns(ctx);
     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
+    populateControlFlowValuesNarrowingPatterns(patterns, solver,
+                                               bitwidthsSupported);
 
     // We specifically need bottom-up traversal as cmpi pattern needs range
     // data, attached to its original argument values.
@@ -548,6 +690,13 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
                                                        bitwidthsSupported);
 }
 
+void mlir::arith::populateControlFlowValuesNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported) {
+  patterns.add<NarrowLoopBounds>(patterns.getContext(), solver,
+                                 bitwidthsSupported);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 81982b5b8ef55..4d85e2f4bcf47 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -524,6 +524,36 @@ bool ForOp::isValidInductionVarType(Type type) {
   return type.isIndex() || type.isSignlessInteger();
 }
 
+LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1)
+    return failure();
+  if (auto val = dyn_cast<Value>(bounds[0])) {
+    setLowerBound(val);
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1)
+    return failure();
+  if (auto val = dyn_cast<Value>(bounds[0])) {
+    setUpperBound(val);
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
+  if (steps.size() != 1)
+    return failure();
+  if (auto val = dyn_cast<Value>(steps[0])) {
+    setStep(val);
+    return success();
+  }
+  return failure();
+}
+
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
 /// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 42dd4294b86e9..63e530c63df33 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -318,10 +318,10 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK-SAME: umax = 63
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
-// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
+// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND_I8]] step %{{.*}}
 // CHECK-DAG:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK-DAG:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
-//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0]] : i8
 //     CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
 //     CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
 //     CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
@@ -349,6 +349,7 @@ func.func @clamp_to_loop_bound_and_id() {
 }
 
 func.func private @use(index)
+func.func private @use_i64(i64)
 
 // CHECK-LABEL: func.func @loop_with_iter_arg
 func.func @loop_with_iter_arg() {
@@ -377,3 +378,35 @@ func.func @loop_with_iter_arg() {
   }
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Loop bounds narrowing
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @narrow_loop_bounds
+func.func @narrow_loop_bounds() {
+  %c0_i64 = arith.constant 0 : i64
+  %c10_i64 = arith.constant 10 : i64
+  %c1_i64 = arith.constant 1 : i64
+
+  // CHECK-DAG: %[[C1_I8:.*]] = arith.constant 1 : i8
+  // CHECK-DAG: %[[LB:.*]] = test.with_bounds {smax = 0 : i64, smin = 0 : i64, umax = 0 : i64, umin = 0 : i64} : i64
+  // CHECK-DAG: %[[UB:.*]] = test.with_bounds {smax = 10 : i64, smin = 10 : i64, umax = 10 : i64, umin = 10 : i64} : i64
+  // CHECK-DAG: %[[STEP:.*]] = test.with_bounds {smax = 1 : i64, smin = 1 : i64, umax = 1 : i64, umin = 1 : i64} : i64
+  %lb = test.with_bounds {smin = 0 : i64, smax = 0 : i64, umin = 0 : i64, umax = 0 : i64} : i64
+  %ub = test.with_bounds {smin = 10 : i64, smax = 10 : i64, umin = 10 : i64, umax = 10 : i64} : i64
+  %step = test.with_bounds {smin = 1 : i64, smax = 1 : i64, umin = 1 : i64, umax = 1 : i64} : i64
+
+  // CHECK-DAG: %[[LB_I8:.*]] = arith.trunci %[[LB]] : i64 to i8
+  // CHECK-DAG: %[[UB_I8:.*]] = arith.trunci %[[UB]] : i64 to i8
+  // CHECK-DAG: %[[STEP_I8:.*]] = arith.trunci %[[STEP]] : i64 to i8
+  // CHECK: scf.for %[[IV:.*]] = %[[LB_I8]] to %[[UB_I8]] step %[[STEP_I8]] : i8 {
+  // CHECK:   %[[ADD_I8:.*]] = arith.addi %[[IV]], %[[C1_I8]] : i8
+  // CHECK:   %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] : i8 to i64
+  // CHECK:   call @use_i64(%[[ADD_I64]])
+  scf.for %iv = %lb to %ub step %step : i64 {
+    %add = arith.addi %iv, %c1_i64 : i64
+    func.call @use_i64(%add) : (i64) -> ()
+  }
+  return
+}

>From 30be93bc0daf93f8ae1a902ad8d8321752aff68b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 20:54:22 +0100
Subject: [PATCH 03/13] check iv overflow

---
 .../Transforms/IntRangeOptimizations.cpp      | 30 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    | 29 +++++++++++-------
 2 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 39038cc82553b..bf396f3bf1eb8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -528,6 +528,11 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
               solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
         continue;
 
+      const ConstantIntRanges &lbRange = ranges[0];
+      const ConstantIntRanges &ubRange = ranges[1];
+      const ConstantIntRanges &stepRange = ranges[2];
+      const ConstantIntRanges &indVarRange = ranges[3];
+
       Type srcType = maybeLb.getType();
 
       // Try each target bitwidth.
@@ -553,6 +558,31 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
         if (castKind == CastKind::None)
           continue;
 
+        // Additional check: ensure indVar + step doesn't overflow.
+        // During loop increment, we compute iv_next = iv_current + step in the
+        // narrowed type. If this overflows, the loop behavior becomes incorrect.
+        // We must check both signed and unsigned ranges since we don't know
+        // whether the loop semantics treat values as signed or unsigned.
+        unsigned srcWidth = lbRange.smin().getBitWidth();
+        unsigned removedWidth = srcWidth - targetBitwidth;
+
+        // Check that max(indVar) + step fits in the target type.
+        APInt indVarPlusStepSmin = indVarRange.smin().sadd_sat(stepRange.smin());
+        APInt indVarPlusStepSmax = indVarRange.smax().sadd_sat(stepRange.smax());
+        APInt indVarPlusStepUmin = indVarRange.umin().uadd_sat(stepRange.umin());
+        APInt indVarPlusStepUmax = indVarRange.umax().uadd_sat(stepRange.umax());
+
+        bool indVarPlusStepFitsSigned =
+            indVarPlusStepSmin.getNumSignBits() >= (removedWidth + 1) &&
+            indVarPlusStepSmax.getNumSignBits() >= (removedWidth + 1);
+        bool indVarPlusStepFitsUnsigned =
+            indVarPlusStepUmin.countLeadingZeros() >= removedWidth &&
+            indVarPlusStepUmax.countLeadingZeros() >= removedWidth;
+
+        // Both signed and unsigned must fit since loop semantics are unknown.
+        if (!indVarPlusStepFitsSigned || !indVarPlusStepFitsUnsigned)
+          continue;
+
         // Narrow the bounds and step values.
         Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
         Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 63e530c63df33..9107bf649b561 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -314,21 +314,28 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // Motivating example for negative number support, added as a test case
 // and simplified
 // CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK-DAG: %[[C64_I8:.+]] = arith.constant 64 : i8
+// CHECK-DAG: %[[C64_I16:.+]] = arith.constant 64 : i16
+// CHECK-DAG: %[[C16_I16:.+]] = arith.constant 16 : i16
+// CHECK-DAG: %[[C0_I16:.+]] = arith.constant 0 : i16
 // CHECK: %[[TID:.+]] = test.with_bounds
 // CHECK-SAME: umax = 63
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
-// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND_I8]] step %{{.*}}
-// CHECK-DAG:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0]] : i8
-//     CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
-//     CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-//     CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
-//     CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
-//     CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
-//     CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
-//     CHECK:   scf.if %[[V3]]
+// Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]]  : i16 {
+// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
+// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+// CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
+// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+// CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
+// CHECK:   scf.if %[[V3]]
 func.func @clamp_to_loop_bound_and_id() {
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index

>From 1a1b53c534923944ab84156915c20008d1fd8999 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 21:00:42 +0100
Subject: [PATCH 04/13] handle failure

---
 .../Transforms/IntRangeOptimizations.cpp      | 20 +++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index bf396f3bf1eb8..01b8ed7ccae54 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -486,6 +486,10 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
 
   LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
                                 PatternRewriter &rewriter) const override {
+    // Skip ops where bounds narrowing previously failed.
+    if (loopLike->hasAttr("arith.bounds_narrowing_failed"))
+      return failure();
+
     auto inductionVars = loopLike.getLoopInductionVars();
     if (!inductionVars.has_value() || inductionVars->empty())
       return rewriter.notifyMatchFailure(loopLike, "no induction variables");
@@ -606,12 +610,21 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
       origTypes.push_back(indVar.getType());
     }
 
+    // Attempt to update bounds and induction variable types.
+    // If this fails, mark the op so we don't try again.
+    bool updateFailed = false;
     rewriter.modifyOpInPlace(loopLike, [&]() {
       // Update the loop bounds and steps.
       if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
           failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
-          failed(loopLike.setLoopSteps(newSteps)))
-        llvm_unreachable("Failed to update loop bounds/steps");
+          failed(loopLike.setLoopSteps(newSteps))) {
+        // Mark op to prevent future attempts. IR was modified (attribute
+        // added), so we must return success() from the pattern.
+        loopLike->setAttr("arith.bounds_narrowing_failed",
+                          rewriter.getUnitAttr());
+        updateFailed = true;
+        return;
+      }
 
       // Update induction variable types.
       for (auto [idx, typeAndCast] : narrowings) {
@@ -624,6 +637,9 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
       }
     });
 
+    if (updateFailed)
+      return success();
+
     // Insert casts back to original type for uses.
     for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
       auto [idx, typeAndCast] = narrowingInfo;

>From 13522414ebaab442f9d2cda2a7a329c9aaaf5f8a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 21:18:09 +0100
Subject: [PATCH 05/13] lib dep

---
 mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 4780dbb4338bb..637e16a3963d6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRFuncTransforms
   MLIRInferIntRangeInterface
   MLIRIR
+  MLIRLoopLikeInterface
   MLIRMemRefDialect
   MLIRShardDialect
   MLIRPass

>From 7d12547c9921c1ce3603b6ab8f95da3afe7ab9bd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 21:21:22 +0100
Subject: [PATCH 06/13] format

---
 .../Transforms/IntRangeOptimizations.cpp      | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 01b8ed7ccae54..d46b888247bc8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -564,17 +564,22 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
 
         // Additional check: ensure indVar + step doesn't overflow.
         // During loop increment, we compute iv_next = iv_current + step in the
-        // narrowed type. If this overflows, the loop behavior becomes incorrect.
-        // We must check both signed and unsigned ranges since we don't know
-        // whether the loop semantics treat values as signed or unsigned.
+        // narrowed type. If this overflows, the loop behavior becomes
+        // incorrect. We must check both signed and unsigned ranges since we
+        // don't know whether the loop semantics treat values as signed or
+        // unsigned.
         unsigned srcWidth = lbRange.smin().getBitWidth();
         unsigned removedWidth = srcWidth - targetBitwidth;
 
         // Check that max(indVar) + step fits in the target type.
-        APInt indVarPlusStepSmin = indVarRange.smin().sadd_sat(stepRange.smin());
-        APInt indVarPlusStepSmax = indVarRange.smax().sadd_sat(stepRange.smax());
-        APInt indVarPlusStepUmin = indVarRange.umin().uadd_sat(stepRange.umin());
-        APInt indVarPlusStepUmax = indVarRange.umax().uadd_sat(stepRange.umax());
+        APInt indVarPlusStepSmin =
+            indVarRange.smin().sadd_sat(stepRange.smin());
+        APInt indVarPlusStepSmax =
+            indVarRange.smax().sadd_sat(stepRange.smax());
+        APInt indVarPlusStepUmin =
+            indVarRange.umin().uadd_sat(stepRange.umin());
+        APInt indVarPlusStepUmax =
+            indVarRange.umax().uadd_sat(stepRange.umax());
 
         bool indVarPlusStepFitsSigned =
             indVarPlusStepSmin.getNumSignBits() >= (removedWidth + 1) &&

>From 1d042644da8ea7ddbed76305106b690a593c9f92 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 11 Jan 2026 21:35:12 +0100
Subject: [PATCH 07/13] unused var

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index d46b888247bc8..c31570dbc8c0a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -533,7 +533,6 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
         continue;
 
       const ConstantIntRanges &lbRange = ranges[0];
-      const ConstantIntRanges &ubRange = ranges[1];
       const ConstantIntRanges &stepRange = ranges[2];
       const ConstantIntRanges &indVarRange = ranges[3];
 

>From 2d6f62911abb22ade28313e89a96477a57096f9c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 01:01:12 +0100
Subject: [PATCH 08/13] cache failed attr

---
 .../Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index c31570dbc8c0a..96829cdbe5a5a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -482,12 +482,14 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
   NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
       : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
-        targetBitwidths(target) {}
+        targetBitwidths(target),
+        boundsNarrowingFailedAttr(
+            StringAttr::get(context, "arith.bounds_narrowing_failed")) {}
 
   LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
                                 PatternRewriter &rewriter) const override {
     // Skip ops where bounds narrowing previously failed.
-    if (loopLike->hasAttr("arith.bounds_narrowing_failed"))
+    if (loopLike->hasAttr(boundsNarrowingFailedAttr))
       return failure();
 
     auto inductionVars = loopLike.getLoopInductionVars();
@@ -624,8 +626,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
           failed(loopLike.setLoopSteps(newSteps))) {
         // Mark op to prevent future attempts. IR was modified (attribute
         // added), so we must return success() from the pattern.
-        loopLike->setAttr("arith.bounds_narrowing_failed",
-                          rewriter.getUnitAttr());
+        loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr());
         updateFailed = true;
         return;
       }
@@ -667,6 +668,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
 private:
   DataFlowSolver &solver;
   SmallVector<unsigned, 4> targetBitwidths;
+  StringAttr boundsNarrowingFailedAttr;
 };
 
 struct IntRangeOptimizationsPass final

>From b4f8c636e23e616bfa6c41a4edc76c13e5b129d7 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 01:02:31 +0100
Subject: [PATCH 09/13] error msg

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 96829cdbe5a5a..c6a4594481b16 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -490,7 +490,8 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
                                 PatternRewriter &rewriter) const override {
     // Skip ops where bounds narrowing previously failed.
     if (loopLike->hasAttr(boundsNarrowingFailedAttr))
-      return failure();
+      return rewriter.notifyMatchFailure(loopLike,
+                                         "bounds narrowing previously failed");
 
     auto inductionVars = loopLike.getLoopInductionVars();
     if (!inductionVars.has_value() || inductionVars->empty())
@@ -607,7 +608,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
     }
 
     if (narrowings.empty())
-      return failure();
+      return rewriter.notifyMatchFailure(loopLike, "no narrowings found");
 
     // Save original types before modifying.
     SmallVector<Type> origTypes;

>From ee75ace7d60d72eb3f4b2aaa4d2a6aeee7c3229f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 01:04:12 +0100
Subject: [PATCH 10/13] spell types

---
 .../Arith/Transforms/IntRangeOptimizations.cpp        | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index c6a4594481b16..e9b90e48a2b27 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -493,13 +493,16 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
       return rewriter.notifyMatchFailure(loopLike,
                                          "bounds narrowing previously failed");
 
-    auto inductionVars = loopLike.getLoopInductionVars();
+    std::optional<SmallVector<Value>> inductionVars =
+        loopLike.getLoopInductionVars();
     if (!inductionVars.has_value() || inductionVars->empty())
       return rewriter.notifyMatchFailure(loopLike, "no induction variables");
 
-    auto lowerBounds = loopLike.getLoopLowerBounds();
-    auto upperBounds = loopLike.getLoopUpperBounds();
-    auto steps = loopLike.getLoopSteps();
+    std::optional<SmallVector<OpFoldResult>> lowerBounds =
+        loopLike.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> upperBounds =
+        loopLike.getLoopUpperBounds();
+    std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
 
     if (!lowerBounds.has_value() || !upperBounds.has_value() ||
         !steps.has_value())

>From 71407f487699576abe890dc5d1e45d50dc37370d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 01:07:46 +0100
Subject: [PATCH 11/13] std::tuple

---
 .../Arith/Transforms/IntRangeOptimizations.cpp       | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index e9b90e48a2b27..f674731e1ace0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -518,7 +518,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
     SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
     SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
     SmallVector<OpFoldResult> newSteps(*steps);
-    SmallVector<std::pair<size_t, std::pair<Type, CastKind>>> narrowings;
+    SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
 
     // Check each (indVar, lb, ub, step) tuple.
     for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
@@ -605,7 +605,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
         newLowerBounds[idx] = newLb;
         newUpperBounds[idx] = newUb;
         newSteps[idx] = newStep;
-        narrowings.push_back({idx, {targetType, castKind}});
+        narrowings.push_back({idx, targetType, castKind});
         break;
       }
     }
@@ -615,7 +615,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
 
     // Save original types before modifying.
     SmallVector<Type> origTypes;
-    for (auto [idx, typeAndCast] : narrowings) {
+    for (auto [idx, targetType, castKind] : narrowings) {
       Value indVar = (*inductionVars)[idx];
       origTypes.push_back(indVar.getType());
     }
@@ -636,8 +636,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
       }
 
       // Update induction variable types.
-      for (auto [idx, typeAndCast] : narrowings) {
-        auto [targetType, castKind] = typeAndCast;
+      for (auto [idx, targetType, castKind] : narrowings) {
         Value indVar = (*inductionVars)[idx];
         auto blockArg = cast<BlockArgument>(indVar);
 
@@ -651,8 +650,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
 
     // Insert casts back to original type for uses.
     for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
-      auto [idx, typeAndCast] = narrowingInfo;
-      auto [targetType, castKind] = typeAndCast;
+      auto [idx, targetType, castKind] = narrowingInfo;
       Value indVar = (*inductionVars)[idx];
       auto blockArg = cast<BlockArgument>(indVar);
       Type origType = origTypes[narrowingIdx];

>From e32266ab0c0f2490f8eeff4dfbe170fd9f6894c6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 01:19:21 +0100
Subject: [PATCH 12/13] refac truncatability check

---
 .../Transforms/IntRangeOptimizations.cpp      | 44 +++++++------------
 1 file changed, 15 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f674731e1ace0..17f579d5918da 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -538,7 +538,6 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
               solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
         continue;
 
-      const ConstantIntRanges &lbRange = ranges[0];
       const ConstantIntRanges &stepRange = ranges[2];
       const ConstantIntRanges &indVarRange = ranges[3];
 
@@ -567,34 +566,21 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
         if (castKind == CastKind::None)
           continue;
 
-        // Additional check: ensure indVar + step doesn't overflow.
-        // During loop increment, we compute iv_next = iv_current + step in the
-        // narrowed type. If this overflows, the loop behavior becomes
-        // incorrect. We must check both signed and unsigned ranges since we
-        // don't know whether the loop semantics treat values as signed or
-        // unsigned.
-        unsigned srcWidth = lbRange.smin().getBitWidth();
-        unsigned removedWidth = srcWidth - targetBitwidth;
-
-        // Check that max(indVar) + step fits in the target type.
-        APInt indVarPlusStepSmin =
-            indVarRange.smin().sadd_sat(stepRange.smin());
-        APInt indVarPlusStepSmax =
-            indVarRange.smax().sadd_sat(stepRange.smax());
-        APInt indVarPlusStepUmin =
-            indVarRange.umin().uadd_sat(stepRange.umin());
-        APInt indVarPlusStepUmax =
-            indVarRange.umax().uadd_sat(stepRange.umax());
-
-        bool indVarPlusStepFitsSigned =
-            indVarPlusStepSmin.getNumSignBits() >= (removedWidth + 1) &&
-            indVarPlusStepSmax.getNumSignBits() >= (removedWidth + 1);
-        bool indVarPlusStepFitsUnsigned =
-            indVarPlusStepUmin.countLeadingZeros() >= removedWidth &&
-            indVarPlusStepUmax.countLeadingZeros() >= removedWidth;
-
-        // Both signed and unsigned must fit since loop semantics are unknown.
-        if (!indVarPlusStepFitsSigned || !indVarPlusStepFitsUnsigned)
+        // Check if indVar + step fits in the narrowed type.
+        // This is critical for loop correctness: the loop computes
+        // iv_next = iv_current + step in the narrowed type, then compares
+        // iv_next < ub. If iv_current + step overflows, the comparison may
+        // produce incorrect results and break loop termination.
+        // Both signed and unsigned interpretations must fit because loop
+        // semantics are unknown (e.g., index type is signless).
+        ConstantIntRanges indVarPlusStepRange(
+            indVarRange.smin().sadd_sat(stepRange.smin()),
+            indVarRange.smax().sadd_sat(stepRange.smax()),
+            indVarRange.umin().uadd_sat(stepRange.umin()),
+            indVarRange.umax().uadd_sat(stepRange.umax()));
+
+        if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
+            CastKind::Both)
           continue;
 
         // Narrow the bounds and step values.

>From a98578666b25cb4753b08ffa187d3f111b46c557 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 12 Jan 2026 19:41:08 +0100
Subject: [PATCH 13/13] fix comment

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 17f579d5918da..fefbba989b996 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -572,7 +572,7 @@ struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
         // iv_next < ub. If iv_current + step overflows, the comparison may
         // produce incorrect results and break loop termination.
         // Both signed and unsigned interpretations must fit because loop
-        // semantics are unknown (e.g., index type is signless).
+        // semantics are unknown (integer types are signless).
         ConstantIntRanges indVarPlusStepRange(
             indVarRange.smin().sadd_sat(stepRange.smin()),
             indVarRange.smax().sadd_sat(stepRange.smax()),



More information about the Mlir-commits mailing list