[Mlir-commits] [mlir] [mlir][arith] Move canonicalization patterns to fold (PR #184381)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Tue Mar 3 12:52:49 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/184381
>From ca1cdd5f808f93fd0a256ad35fa21d560162e45d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:04:30 -0500
Subject: [PATCH 1/3] Move canonicalization to fold
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 9 -----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 10 ++++-
.../Transforms/IntRangeOptimizations.cpp | 12 +++---
mlir/test/Dialect/Arith/canonicalize.mlir | 37 +++++++++++++++++++
.../Dialect/Arith/int-range-narrowing.mlir | 12 +++---
mlir/test/Transforms/canonicalize.mlir | 9 -----
6 files changed, 59 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f26af4816ce85..baabbca99e37c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -288,15 +288,6 @@ def SelectI1ToNot :
// IndexCastOp
//===----------------------------------------------------------------------===//
-// index_cast(index_cast(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastOfIndexCast :
- Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
- (replaceWithValue $x),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
- (Constraint<CPred<"(bool)$0">> $exact1)]>;
-
// index_cast(extsi(x)) -> index_cast(x)
def IndexCastOfExtSI :
Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x), $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b99f77fdc8b30..ad837bea01510 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1845,6 +1845,14 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
}
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
+ // index_cast(index_cast(x, exact)) -> x, when result type matches x's type.
+ // The inner exact guarantees the iN -> index conversion is lossless,
+ // so the roundtrip through index preserves the value.
+ if (auto innerCast = getIn().getDefiningOp<IndexCastOp>()) {
+ if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+ return innerCast.getIn();
+ }
+
// index_cast(constant) -> constant
unsigned resultBitwidth = 64; // Default for index integer attributes.
if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1859,7 +1867,7 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
void arith::IndexCastOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+ patterns.add<IndexCastOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..df3eea25a8d20 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -305,8 +305,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
return src;
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
- if (castKind == CastKind::Signed)
- return arith::IndexCastOp::create(builder, loc, dstType, src);
+ if (castKind == CastKind::Signed) {
+ auto cast = arith::IndexCastOp::create(builder, loc, dstType, src);
+ cast.setExact(true);
+ return cast;
+ }
return arith::IndexCastUIOp::create(builder, loc, dstType, src);
}
@@ -725,9 +728,8 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
ArrayRef<unsigned> bitwidthsSupported) {
patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
bitwidthsSupported);
- patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
- FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
- bitwidthsSupported);
+ patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
+ bitwidthsSupported);
}
void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a2d0eff47ad92..0a5a7cc56517a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -643,6 +643,43 @@ func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
return %idx : index
}
+// index_cast(index_cast(x)) -> x only when exact is on the inner cast.
+// CHECK-LABEL: @indexCastOfIndexCast_no_exact
+// CHECK: arith.index_cast
+// CHECK: arith.index_cast
+func.func @indexCastOfIndexCast_no_exact(%arg0: i16) -> i16 {
+ %idx = arith.index_cast %arg0 : i16 to index
+ %res = arith.index_cast %idx : index to i16
+ return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_inner
+// CHECK: return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_inner(%arg0: i16) -> i16 {
+ %idx = arith.index_cast %arg0 exact : i16 to index
+ %res = arith.index_cast %idx : index to i16
+ return %res : i16
+}
+
+// exact on outer only does NOT trigger the fold: if index is narrower than i16,
+// the outer widening is vacuously exact while the inner truncation may be lossy.
+// CHECK-LABEL: @indexCastOfIndexCast_exact_outer
+// CHECK: arith.index_cast
+// CHECK: arith.index_cast
+func.func @indexCastOfIndexCast_exact_outer(%arg0: i16) -> i16 {
+ %idx = arith.index_cast %arg0 : i16 to index
+ %res = arith.index_cast %idx exact : index to i16
+ return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_both
+// CHECK: return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_both(%arg0: i16) -> i16 {
+ %idx = arith.index_cast %arg0 exact : i16 to index
+ %res = arith.index_cast %idx exact : index to i16
+ return %res : i16
+}
+
// index_castui(index_castui(x)) -> x only when exact is on the inner cast.
// CHECK-LABEL: @indexCastUIOfIndexCastUI_no_exact
// CHECK: arith.index_castui
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index c3b0d280b1350..5a26d1d0e7fa4 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -9,9 +9,9 @@
// CHECK: %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
// CHECK: return %[[RES]] : index
func.func @test_addi_neg() -> index {
%0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -330,8 +330,8 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
-// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
+// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
@@ -371,9 +371,9 @@ func.func @loop_with_iter_arg() {
// Check iter args are still present
// CHECK: scf.for {{.*}} iter_args({{.*}})
// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
// CHECK: call @use(%[[RES]])
%0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3930575c45b3e..3ea6619ee589d 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -892,15 +892,6 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
return %7, %8 : index, index
}
-// CHECK-LABEL: func @index_cast
-// CHECK-SAME: %[[ARG_0:arg[0-9]+]]: i16
-func.func @index_cast(%arg0: i16) -> (i16) {
- %11 = arith.index_cast %arg0 exact : i16 to index
- %12 = arith.index_cast %11 : index to i16
- // CHECK: return %[[ARG_0]] : i16
- return %12 : i16
-}
-
// CHECK-LABEL: func @index_cast_fold
func.func @index_cast_fold() -> (i16, index) {
%c4 = arith.constant 4 : index
>From cc13e2161e9dd42bd3bfdb59e7cb4f7688050509 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:23:15 -0500
Subject: [PATCH 2/3] Move cast_ui to fold
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 10 -----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 +++-
.../Transforms/IntRangeOptimizations.cpp | 39 ++----------------
.../Dialect/Arith/int-range-narrowing.mlir | 40 +++++++++----------
4 files changed, 30 insertions(+), 67 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index baabbca99e37c..0d87092b7721f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -297,16 +297,6 @@ def IndexCastOfExtSI :
// IndexCastUIOp
//===----------------------------------------------------------------------===//
-// index_castui(index_castui(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastUIOfIndexCastUI :
- Pat<(Arith_IndexCastUIOp:$res
- (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
- (replaceWithValue $x),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
- (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;
-
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2, $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ad837bea01510..625e2c3bb13fc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1880,6 +1880,12 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
}
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
+ // index_castui(index_castui(x, exact)) -> x, when result type matches x's type.
+ if (auto innerCast = getIn().getDefiningOp<IndexCastUIOp>()) {
+ if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+ return innerCast.getIn();
+ }
+
// index_castui(constant) -> constant
unsigned resultBitwidth = 64; // Default for index integer attributes.
if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1894,7 +1900,7 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
void arith::IndexCastUIOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+ patterns.add<IndexCastUIOfExtUI>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index df3eea25a8d20..b7f6e525a6a74 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -310,7 +310,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
cast.setExact(true);
return cast;
}
- return arith::IndexCastUIOp::create(builder, loc, dstType, src);
+ auto cast = arith::IndexCastUIOp::create(builder, loc, dstType, src);
+ cast.setExact(true);
+ return cast;
}
auto srcInt = cast<IntegerType>(srcElemType);
@@ -448,39 +450,6 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
SmallVector<unsigned, 4> targetBitwidths;
};
-/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
-/// This pattern assumes all passed `targetBitwidths` are not wider than index
-/// type.
-template <typename CastOp>
-struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
- FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
- : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
-
- LogicalResult matchAndRewrite(CastOp op,
- PatternRewriter &rewriter) const override {
- auto srcOp = op.getIn().template getDefiningOp<CastOp>();
- if (!srcOp)
- return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
-
- Value src = srcOp.getIn();
- if (src.getType() != op.getType())
- return rewriter.notifyMatchFailure(op, "outer types don't match");
-
- if (!srcOp.getType().isIndex())
- return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
-
- auto intType = dyn_cast<IntegerType>(op.getType());
- if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
- return failure();
-
- rewriter.replaceOp(op, src);
- return success();
- }
-
-private:
- SmallVector<unsigned, 4> targetBitwidths;
-};
-
struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
ArrayRef<unsigned> target)
@@ -728,8 +697,6 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
ArrayRef<unsigned> bitwidthsSupported) {
patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
bitwidthsSupported);
- patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
- bitwidthsSupported);
}
void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 5a26d1d0e7fa4..e9c232df06b77 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -8,7 +8,7 @@
// CHECK-LABEL: func @test_addi_neg
// CHECK: %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
-// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -23,10 +23,10 @@ func.func @test_addi_neg() -> index {
// CHECK-LABEL: func @test_addi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : i8 to index
// CHECK: return %[[RES_CASTED]] : index
func.func @test_addi() -> index {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -38,10 +38,10 @@ func.func @test_addi() -> index {
// CHECK-LABEL: func @test_addi_vec
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex>
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : vector<4xi8> to vector<4xindex>
// CHECK: return %[[RES_CASTED]] : vector<4xindex>
func.func @test_addi_vec() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
@@ -68,8 +68,8 @@ func.func @test_addi_i64() -> i64 {
// CHECK-LABEL: func @test_cmpi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
// CHECK: return %[[RES]] : i1
func.func @test_cmpi() -> i1 {
@@ -82,8 +82,8 @@ func.func @test_cmpi() -> i1 {
// CHECK-LABEL: func @test_cmpi_vec
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
// CHECK: return %[[RES]] : vector<4xi1>
func.func @test_cmpi_vec() -> vector<4xi1> {
@@ -97,10 +97,10 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
// CHECK: %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
+// CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact : index to i8
// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
// CHECK: return %[[RES2]] : i1
func.func @test_add_cmpi() -> i1 {
@@ -323,16 +323,16 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
// CHECK: %[[BOUND:.+]] = test.with_bounds
// CHECK-SAME: umax = 112
// Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
-// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact : index to i16
// CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]] : i16 {
-// CHECK: %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
-// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK: %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact : i16 to index
+// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact : index to i8
+// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact : index to i8
// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
-// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact : index to i16
// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
// CHECK: scf.if %[[V3]]
@@ -370,7 +370,7 @@ func.func @loop_with_iter_arg() {
// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
// Check iter args are still present
// CHECK: scf.for {{.*}} iter_args({{.*}})
-// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
>From 180f3cf34e8385b5dd47d5b921e7de6fc0a8afd7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 15:52:31 -0500
Subject: [PATCH 3/3] Set nneg flag
---
.../Transforms/IntRangeOptimizations.cpp | 11 ++++-
.../Dialect/Arith/int-range-narrowing.mlir | 46 +++++++++----------
2 files changed, 33 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index b7f6e525a6a74..81ec8cf3fd863 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -312,6 +312,12 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
}
auto cast = arith::IndexCastUIOp::create(builder, loc, dstType, src);
cast.setExact(true);
+ // Narrowing (index -> iN): the unsigned range fits in N < index width
+ // bits, so the top bits including the MSB are all zero.
+ // Widening (iN -> index): the MSB is zero only when the iN value is
+ // representable in both the signed and unsigned iN ranges (CastKind::Both).
+ if (isa<IndexType>(srcElemType) || castKind == CastKind::Both)
+ cast.setNonNeg(true);
return cast;
}
@@ -322,7 +328,10 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
if (castKind == CastKind::Signed)
return arith::ExtSIOp::create(builder, loc, dstType, src);
- return arith::ExtUIOp::create(builder, loc, dstType, src);
+ auto ext = arith::ExtUIOp::create(builder, loc, dstType, src);
+ if (castKind == CastKind::Both)
+ ext.setNonNeg(true);
+ return ext;
}
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index e9c232df06b77..5797a0c83bca7 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -8,7 +8,7 @@
// CHECK-LABEL: func @test_addi_neg
// CHECK: %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
-// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact nneg : index to i8
// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -23,10 +23,10 @@ func.func @test_addi_neg() -> index {
// CHECK-LABEL: func @test_addi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : i8 to index
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact nneg : i8 to index
// CHECK: return %[[RES_CASTED]] : index
func.func @test_addi() -> index {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -38,10 +38,10 @@ func.func @test_addi() -> index {
// CHECK-LABEL: func @test_addi_vec
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : vector<4xindex> to vector<4xi8>
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : vector<4xindex> to vector<4xi8>
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : vector<4xi8> to vector<4xindex>
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact nneg : vector<4xi8> to vector<4xindex>
// CHECK: return %[[RES_CASTED]] : vector<4xindex>
func.func @test_addi_vec() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
@@ -56,7 +56,7 @@ func.func @test_addi_vec() -> vector<4xindex> {
// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64
+// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] nneg : i8 to i64
// CHECK: return %[[RES_CASTED]] : i64
func.func @test_addi_i64() -> i64 {
%0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
@@ -68,8 +68,8 @@ func.func @test_addi_i64() -> i64 {
// CHECK-LABEL: func @test_cmpi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
// CHECK: return %[[RES]] : i1
func.func @test_cmpi() -> i1 {
@@ -82,8 +82,8 @@ func.func @test_cmpi() -> i1 {
// CHECK-LABEL: func @test_cmpi_vec
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : vector<4xindex> to vector<4xi8>
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : vector<4xindex> to vector<4xi8>
// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
// CHECK: return %[[RES]] : vector<4xi1>
func.func @test_cmpi_vec() -> vector<4xi1> {
@@ -97,10 +97,10 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact nneg : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact nneg : index to i8
// CHECK: %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact : index to i8
+// CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact nneg : index to i8
// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
// CHECK: return %[[RES2]] : i1
func.func @test_add_cmpi() -> i1 {
@@ -142,7 +142,7 @@ func.func @test_add_cmpi_i64() -> i1 {
// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] nneg : i16 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
@@ -323,16 +323,16 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
// CHECK: %[[BOUND:.+]] = test.with_bounds
// CHECK-SAME: umax = 112
// Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
-// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact : index to i16
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact nneg : index to i16
// CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]] : i16 {
-// CHECK: %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact : i16 to index
-// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact : index to i8
-// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact : index to i8
+// CHECK: %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact nneg : i16 to index
+// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact nneg : index to i8
+// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact nneg : index to i8
// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
-// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact : index to i16
+// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact nneg : index to i16
// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
// CHECK: scf.if %[[V3]]
@@ -370,7 +370,7 @@ func.func @loop_with_iter_arg() {
// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
// Check iter args are still present
// CHECK: scf.for {{.*}} iter_args({{.*}})
-// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact nneg : index to i8
// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -409,7 +409,7 @@ func.func @narrow_loop_bounds() {
// CHECK-DAG: %[[STEP_I8:.*]] = arith.trunci %[[STEP]] : i64 to i8
// CHECK: scf.for %[[IV:.*]] = %[[LB_I8]] to %[[UB_I8]] step %[[STEP_I8]] : i8 {
// CHECK: %[[ADD_I8:.*]] = arith.addi %[[IV]], %[[C1_I8]] : i8
- // CHECK: %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] : i8 to i64
+ // CHECK: %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] nneg : i8 to i64
// CHECK: call @use_i64(%[[ADD_I64]])
scf.for %iv = %lb to %ub step %step : i64 {
%add = arith.addi %iv, %c1_i64 : i64
More information about the Mlir-commits
mailing list