[Mlir-commits] [mlir] [mlir][arith] Move canonicalization patterns to fold (PR #184381)

Erick Ochoa Lopez llvmlistbot at llvm.org
Tue Mar 3 12:52:49 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/184381

>From ca1cdd5f808f93fd0a256ad35fa21d560162e45d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:04:30 -0500
Subject: [PATCH 1/3] Move index_cast(index_cast) canonicalization to fold

---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  9 -----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 10 ++++-
 .../Transforms/IntRangeOptimizations.cpp      | 12 +++---
 mlir/test/Dialect/Arith/canonicalize.mlir     | 37 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    | 12 +++---
 mlir/test/Transforms/canonicalize.mlir        |  9 -----
 6 files changed, 59 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f26af4816ce85..baabbca99e37c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -288,15 +288,6 @@ def SelectI1ToNot :
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-// index_cast(index_cast(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastOfIndexCast :
-    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
-         (Constraint<CPred<"(bool)$0">> $exact1)]>;
-
 // index_cast(extsi(x)) -> index_cast(x)
 def IndexCastOfExtSI :
     Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x), $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b99f77fdc8b30..ad837bea01510 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1845,6 +1845,14 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
 }
 
 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
+  // index_cast(index_cast(x, exact)) -> x, when result type matches x's type.
+  // The inner exact guarantees the iN -> index conversion is lossless,
+  // so the roundtrip through index preserves the value.
+  if (auto innerCast = getIn().getDefiningOp<IndexCastOp>()) {
+    if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+      return innerCast.getIn();
+  }
+
   // index_cast(constant) -> constant
   unsigned resultBitwidth = 64; // Default for index integer attributes.
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1859,7 +1867,7 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
 
 void arith::IndexCastOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+  patterns.add<IndexCastOfExtSI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..df3eea25a8d20 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -305,8 +305,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
     return src;
 
   if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
-    if (castKind == CastKind::Signed)
-      return arith::IndexCastOp::create(builder, loc, dstType, src);
+    if (castKind == CastKind::Signed) {
+      auto cast = arith::IndexCastOp::create(builder, loc, dstType, src);
+      cast.setExact(true);
+      return cast;
+    }
     return arith::IndexCastUIOp::create(builder, loc, dstType, src);
   }
 
@@ -725,9 +728,8 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
-               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
-                                                       bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
+                                                        bitwidthsSupported);
 }
 
 void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a2d0eff47ad92..0a5a7cc56517a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -643,6 +643,43 @@ func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
   return %idx : index
 }
 
+// index_cast(index_cast(x)) -> x only when exact is on the inner cast.
+// CHECK-LABEL: @indexCastOfIndexCast_no_exact
+//       CHECK:   arith.index_cast
+//       CHECK:   arith.index_cast
+func.func @indexCastOfIndexCast_no_exact(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 : i16 to index
+  %res = arith.index_cast %idx : index to i16
+  return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_inner
+//       CHECK:   return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_inner(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 exact : i16 to index
+  %res = arith.index_cast %idx : index to i16
+  return %res : i16
+}
+
+// exact on outer only does NOT trigger the fold (if x is wider than index the
+// outer cast widens, so its exact is vacuous: the inner cast could be lossy).
+// CHECK-LABEL: @indexCastOfIndexCast_exact_outer
+//       CHECK:   arith.index_cast
+//       CHECK:   arith.index_cast
+func.func @indexCastOfIndexCast_exact_outer(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 : i16 to index
+  %res = arith.index_cast %idx exact : index to i16
+  return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_both
+//       CHECK:   return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_both(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 exact : i16 to index
+  %res = arith.index_cast %idx exact : index to i16
+  return %res : i16
+}
+
 // index_castui(index_castui(x)) -> x only when exact is on the inner cast.
 // CHECK-LABEL: @indexCastUIOfIndexCastUI_no_exact
 //       CHECK:   arith.index_castui
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index c3b0d280b1350..5a26d1d0e7fa4 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -9,9 +9,9 @@
 //       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
 //       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
 //       CHECK:  return %[[RES]] : index
 func.func @test_addi_neg() -> index {
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -330,8 +330,8 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
 // CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
 // CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
-// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
+// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
 // CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
 // CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
 // CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
@@ -371,9 +371,9 @@ func.func @loop_with_iter_arg() {
 // Check iter args are still present
 //       CHECK:  scf.for {{.*}} iter_args({{.*}})
 //       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
 //       CHECK:  call @use(%[[RES]])
 
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3930575c45b3e..3ea6619ee589d 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -892,15 +892,6 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
   return %7, %8 : index, index
 }
 
-// CHECK-LABEL: func @index_cast
-// CHECK-SAME: %[[ARG_0:arg[0-9]+]]: i16
-func.func @index_cast(%arg0: i16) -> (i16) {
-  %11 = arith.index_cast %arg0 exact : i16 to index
-  %12 = arith.index_cast %11 : index to i16
-  // CHECK: return %[[ARG_0]] : i16
-  return %12 : i16
-}
-
 // CHECK-LABEL: func @index_cast_fold
 func.func @index_cast_fold() -> (i16, index) {
   %c4 = arith.constant 4 : index

>From cc13e2161e9dd42bd3bfdb59e7cb4f7688050509 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:23:15 -0500
Subject: [PATCH 2/3] Move index_castui(index_castui) canonicalization to fold

---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 10 -----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  8 +++-
 .../Transforms/IntRangeOptimizations.cpp      | 39 ++----------------
 .../Dialect/Arith/int-range-narrowing.mlir    | 40 +++++++++----------
 4 files changed, 30 insertions(+), 67 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index baabbca99e37c..0d87092b7721f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -297,16 +297,6 @@ def IndexCastOfExtSI :
 // IndexCastUIOp
 //===----------------------------------------------------------------------===//
 
-// index_castui(index_castui(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res
-          (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
-         (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;
-
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
     Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2, $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ad837bea01510..625e2c3bb13fc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1880,6 +1880,12 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
 }
 
 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
+  // index_castui(index_castui(x, exact)) -> x, if dstType == srcType.
+  if (auto innerCast = getIn().getDefiningOp<IndexCastUIOp>()) {
+    if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+      return innerCast.getIn();
+  }
+
   // index_castui(constant) -> constant
   unsigned resultBitwidth = 64; // Default for index integer attributes.
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1894,7 +1900,7 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
 
 void arith::IndexCastUIOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+  patterns.add<IndexCastUIOfExtUI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index df3eea25a8d20..b7f6e525a6a74 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -310,7 +310,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
       cast.setExact(true);
       return cast;
     }
-    return arith::IndexCastUIOp::create(builder, loc, dstType, src);
+    auto cast = arith::IndexCastUIOp::create(builder, loc, dstType, src);
+    cast.setExact(true);
+    return cast;
   }
 
   auto srcInt = cast<IntegerType>(srcElemType);
@@ -448,39 +450,6 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
-/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
-/// This pattern assumes all passed `targetBitwidths` are not wider than index
-/// type.
-template <typename CastOp>
-struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
-  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
-
-  LogicalResult matchAndRewrite(CastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
-    if (!srcOp)
-      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
-
-    Value src = srcOp.getIn();
-    if (src.getType() != op.getType())
-      return rewriter.notifyMatchFailure(op, "outer types don't match");
-
-    if (!srcOp.getType().isIndex())
-      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
-
-    auto intType = dyn_cast<IntegerType>(op.getType());
-    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
-      return failure();
-
-    rewriter.replaceOp(op, src);
-    return success();
-  }
-
-private:
-  SmallVector<unsigned, 4> targetBitwidths;
-};
-
 struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
   NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
@@ -728,8 +697,6 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
-                                                        bitwidthsSupported);
 }
 
 void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 5a26d1d0e7fa4..e9c232df06b77 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -8,7 +8,7 @@
 // CHECK-LABEL: func @test_addi_neg
 //       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -23,10 +23,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : i8 to index
 //       CHECK:  return %[[RES_CASTED]] : index
 func.func @test_addi() -> index {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -38,10 +38,10 @@ func.func @test_addi() -> index {
 // CHECK-LABEL: func @test_addi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex>
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : vector<4xi8> to vector<4xindex>
 //       CHECK:  return %[[RES_CASTED]] : vector<4xindex>
 func.func @test_addi_vec() -> vector<4xindex> {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
@@ -68,8 +68,8 @@ func.func @test_addi_i64() -> i64 {
 // CHECK-LABEL: func @test_cmpi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
 //       CHECK:  return %[[RES]] : i1
 func.func @test_cmpi() -> i1 {
@@ -82,8 +82,8 @@ func.func @test_cmpi() -> i1 {
 // CHECK-LABEL: func @test_cmpi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
 //       CHECK:  return %[[RES]] : vector<4xi1>
 func.func @test_cmpi_vec() -> vector<4xi1> {
@@ -97,10 +97,10 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
+//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact : index to i8
 //       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
 //       CHECK:  return %[[RES2]] : i1
 func.func @test_add_cmpi() -> i1 {
@@ -323,16 +323,16 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
 // Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
-// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact : index to i16
 // CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]]  : i16 {
-// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
-// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact : i16 to index
+// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact : index to i8
+// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact : index to i8
 // CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
 // CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
 // CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
 // CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
-// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact : index to i16
 // CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
 // CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
 // CHECK:   scf.if %[[V3]]
@@ -370,7 +370,7 @@ func.func @loop_with_iter_arg() {
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
 // Check iter args are still present
 //       CHECK:  scf.for {{.*}} iter_args({{.*}})
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index

>From 180f3cf34e8385b5dd47d5b921e7de6fc0a8afd7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 15:52:31 -0500
Subject: [PATCH 3/3] Set nneg flag on casts emitted by int-range narrowing

---
 .../Transforms/IntRangeOptimizations.cpp      | 11 ++++-
 .../Dialect/Arith/int-range-narrowing.mlir    | 46 +++++++++----------
 2 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index b7f6e525a6a74..81ec8cf3fd863 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -312,6 +312,12 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
     }
     auto cast = arith::IndexCastUIOp::create(builder, loc, dstType, src);
     cast.setExact(true);
+    // Narrowing (index -> iN): the unsigned range fits in N < index width
+    // bits, so the top bits including the MSB are all zero.
+    // Widening (iN -> index): the iN source's sign bit is known to be zero
+    // only when the value fits both signed and unsigned (Both).
+    if (isa<IndexType>(srcElemType) || castKind == CastKind::Both)
+      cast.setNonNeg(true);
     return cast;
   }
 
@@ -322,7 +328,10 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
 
   if (castKind == CastKind::Signed)
     return arith::ExtSIOp::create(builder, loc, dstType, src);
-  return arith::ExtUIOp::create(builder, loc, dstType, src);
+  auto ext = arith::ExtUIOp::create(builder, loc, dstType, src);
+  if (castKind == CastKind::Both)
+    ext.setNonNeg(true);
+  return ext;
 }
 
 struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index e9c232df06b77..5797a0c83bca7 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -8,7 +8,7 @@
 // CHECK-LABEL: func @test_addi_neg
 //       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact nneg : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -23,10 +23,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : i8 to index
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact nneg : i8 to index
 //       CHECK:  return %[[RES_CASTED]] : index
 func.func @test_addi() -> index {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -38,10 +38,10 @@ func.func @test_addi() -> index {
 // CHECK-LABEL: func @test_addi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : vector<4xi8> to vector<4xindex>
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact nneg : vector<4xi8> to vector<4xindex>
 //       CHECK:  return %[[RES_CASTED]] : vector<4xindex>
 func.func @test_addi_vec() -> vector<4xindex> {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
@@ -56,7 +56,7 @@ func.func @test_addi_vec() -> vector<4xindex> {
 //       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
 //       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] nneg : i8 to i64
 //       CHECK:  return %[[RES_CASTED]] : i64
 func.func @test_addi_i64() -> i64 {
   %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
@@ -68,8 +68,8 @@ func.func @test_addi_i64() -> i64 {
 // CHECK-LABEL: func @test_cmpi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
 //       CHECK:  return %[[RES]] : i1
 func.func @test_cmpi() -> i1 {
@@ -82,8 +82,8 @@ func.func @test_cmpi() -> i1 {
 // CHECK-LABEL: func @test_cmpi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
 //       CHECK:  return %[[RES]] : vector<4xi1>
 func.func @test_cmpi_vec() -> vector<4xi1> {
@@ -97,10 +97,10 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
 //       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact : index to i8
+//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact nneg : index to i8
 //       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
 //       CHECK:  return %[[RES2]] : i1
 func.func @test_add_cmpi() -> i1 {
@@ -142,7 +142,7 @@ func.func @test_add_cmpi_i64() -> i1 {
 // CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
 // CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
 // CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[ADD]] nneg : i16 to i32
 // CHECK-NEXT:    return %[[RET]] : i32
 func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
@@ -323,16 +323,16 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
 // Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
-// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact : index to i16
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact nneg : index to i16
 // CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]]  : i16 {
-// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact : i16 to index
-// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact : index to i8
-// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact : index to i8
+// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact nneg : i16 to index
+// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact nneg : index to i8
+// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact nneg : index to i8
 // CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
 // CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
 // CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
 // CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
-// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact : index to i16
+// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact nneg : index to i16
 // CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
 // CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
 // CHECK:   scf.if %[[V3]]
@@ -370,7 +370,7 @@ func.func @loop_with_iter_arg() {
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
 // Check iter args are still present
 //       CHECK:  scf.for {{.*}} iter_args({{.*}})
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact nneg : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -409,7 +409,7 @@ func.func @narrow_loop_bounds() {
   // CHECK-DAG: %[[STEP_I8:.*]] = arith.trunci %[[STEP]] : i64 to i8
   // CHECK: scf.for %[[IV:.*]] = %[[LB_I8]] to %[[UB_I8]] step %[[STEP_I8]] : i8 {
   // CHECK:   %[[ADD_I8:.*]] = arith.addi %[[IV]], %[[C1_I8]] : i8
-  // CHECK:   %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] : i8 to i64
+  // CHECK:   %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] nneg : i8 to i64
   // CHECK:   call @use_i64(%[[ADD_I64]])
   scf.for %iv = %lb to %ub step %step : i64 {
     %add = arith.addi %iv, %c1_i64 : i64



More information about the Mlir-commits mailing list