[Mlir-commits] [mlir] [mlir][arith] Move canonicalization patterns to fold (PR #184381)

Erick Ochoa Lopez llvmlistbot at llvm.org
Tue Mar 3 08:59:05 PST 2026


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/184381

https://github.com/llvm/llvm-project/pull/183395 added canonicalization patterns of the form `cast(cast(x, exact)) -> x` (for both `arith.index_cast` and `arith.index_castui`).
[One review comment stated the following:](https://github.com/llvm/llvm-project/pull/183395#discussion_r2866412276)

> Other note to future us: these should probably be folders, not canonicalization patterns. (the easy ones where there are matching pairs to cancel)

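Moving these patterns into folders allows them to be reused by the integer range narrowing patterns, which run folders but not canonicalization patterns. To make the folders applicable there, the index casts created during narrowing are now marked `exact`. This in turn allows us to remove `FoldIndexCastChain`.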

>From ca1cdd5f808f93fd0a256ad35fa21d560162e45d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:04:30 -0500
Subject: [PATCH 1/2] Move index_cast canonicalization to fold

---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  9 -----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 10 ++++-
 .../Transforms/IntRangeOptimizations.cpp      | 12 +++---
 mlir/test/Dialect/Arith/canonicalize.mlir     | 37 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    | 12 +++---
 mlir/test/Transforms/canonicalize.mlir        |  9 -----
 6 files changed, 59 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f26af4816ce85..baabbca99e37c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -288,15 +288,6 @@ def SelectI1ToNot :
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-// index_cast(index_cast(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastOfIndexCast :
-    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
-         (Constraint<CPred<"(bool)$0">> $exact1)]>;
-
 // index_cast(extsi(x)) -> index_cast(x)
 def IndexCastOfExtSI :
     Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x), $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b99f77fdc8b30..ad837bea01510 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1845,6 +1845,14 @@ bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
 }
 
 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
+  // index_cast(index_cast(x, exact)) -> x, when result type matches x's type.
+  // The inner exact guarantees the iN -> index conversion is lossless,
+  // so the roundtrip through index preserves the value.
+  if (auto innerCast = getIn().getDefiningOp<IndexCastOp>()) {
+    if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+      return innerCast.getIn();
+  }
+
   // index_cast(constant) -> constant
   unsigned resultBitwidth = 64; // Default for index integer attributes.
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1859,7 +1867,7 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
 
 void arith::IndexCastOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+  patterns.add<IndexCastOfExtSI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..df3eea25a8d20 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -305,8 +305,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
     return src;
 
   if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
-    if (castKind == CastKind::Signed)
-      return arith::IndexCastOp::create(builder, loc, dstType, src);
+    if (castKind == CastKind::Signed) {
+      auto cast = arith::IndexCastOp::create(builder, loc, dstType, src);
+      cast.setExact(true);
+      return cast;
+    }
     return arith::IndexCastUIOp::create(builder, loc, dstType, src);
   }
 
@@ -725,9 +728,8 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
-               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
-                                                       bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
+                                                        bitwidthsSupported);
 }
 
 void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a2d0eff47ad92..0a5a7cc56517a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -643,6 +643,43 @@ func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
   return %idx : index
 }
 
+// index_cast(index_cast(x)) -> x only when exact is on the inner cast.
+// CHECK-LABEL: @indexCastOfIndexCast_no_exact
+//       CHECK:   arith.index_cast
+//       CHECK:   arith.index_cast
+func.func @indexCastOfIndexCast_no_exact(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 : i16 to index
+  %res = arith.index_cast %idx : index to i16
+  return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_inner
+//       CHECK:   return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_inner(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 exact : i16 to index
+  %res = arith.index_cast %idx : index to i16
+  return %res : i16
+}
+
+// exact on the outer (truncating) cast alone does NOT trigger the fold: only
+// the inner cast's exact flag is consulted to prove the round trip lossless.
+// CHECK-LABEL: @indexCastOfIndexCast_exact_outer
+//       CHECK:   arith.index_cast
+//       CHECK:   arith.index_cast
+func.func @indexCastOfIndexCast_exact_outer(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 : i16 to index
+  %res = arith.index_cast %idx exact : index to i16
+  return %res : i16
+}
+
+// CHECK-LABEL: @indexCastOfIndexCast_exact_both
+//       CHECK:   return %arg0 : i16
+func.func @indexCastOfIndexCast_exact_both(%arg0: i16) -> i16 {
+  %idx = arith.index_cast %arg0 exact : i16 to index
+  %res = arith.index_cast %idx exact : index to i16
+  return %res : i16
+}
+
 // index_castui(index_castui(x)) -> x only when exact is on the inner cast.
 // CHECK-LABEL: @indexCastUIOfIndexCastUI_no_exact
 //       CHECK:   arith.index_castui
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index c3b0d280b1350..5a26d1d0e7fa4 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -9,9 +9,9 @@
 //       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
 //       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
 //       CHECK:  return %[[RES]] : index
 func.func @test_addi_neg() -> index {
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -330,8 +330,8 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
 // CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
 // CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
-// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
+// CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
 // CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
 // CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
 // CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
@@ -371,9 +371,9 @@ func.func @loop_with_iter_arg() {
 // Check iter args are still present
 //       CHECK:  scf.for {{.*}} iter_args({{.*}})
 //       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
-//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
-//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
 //       CHECK:  call @use(%[[RES]])
 
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3930575c45b3e..3ea6619ee589d 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -892,15 +892,6 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
   return %7, %8 : index, index
 }
 
-// CHECK-LABEL: func @index_cast
-// CHECK-SAME: %[[ARG_0:arg[0-9]+]]: i16
-func.func @index_cast(%arg0: i16) -> (i16) {
-  %11 = arith.index_cast %arg0 exact : i16 to index
-  %12 = arith.index_cast %11 : index to i16
-  // CHECK: return %[[ARG_0]] : i16
-  return %12 : i16
-}
-
 // CHECK-LABEL: func @index_cast_fold
 func.func @index_cast_fold() -> (i16, index) {
   %c4 = arith.constant 4 : index

>From cc13e2161e9dd42bd3bfdb59e7cb4f7688050509 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 3 Mar 2026 11:23:15 -0500
Subject: [PATCH 2/2] Move index_castui canonicalization to fold

---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 10 -----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  8 +++-
 .../Transforms/IntRangeOptimizations.cpp      | 39 ++----------------
 .../Dialect/Arith/int-range-narrowing.mlir    | 40 +++++++++----------
 4 files changed, 30 insertions(+), 67 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index baabbca99e37c..0d87092b7721f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -297,16 +297,6 @@ def IndexCastOfExtSI :
 // IndexCastUIOp
 //===----------------------------------------------------------------------===//
 
-// index_castui(index_castui(x, exact)) -> x, if dstType == srcType.
-// The inner exact guarantees the iN -> index conversion is lossless,
-// so the roundtrip through index preserves the value.
-def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res
-          (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
-         (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;
-
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
     Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2, $exact),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ad837bea01510..625e2c3bb13fc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1880,6 +1880,12 @@ bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
 }
 
 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
+  // index_castui(index_castui(x, exact)) -> x, if result type == x's type.
+  if (auto innerCast = getIn().getDefiningOp<IndexCastUIOp>()) {
+    if (innerCast.getIn().getType() == getType() && innerCast.getIsExact())
+      return innerCast.getIn();
+  }
+
   // index_castui(constant) -> constant
   unsigned resultBitwidth = 64; // Default for index integer attributes.
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
@@ -1894,7 +1900,7 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
 
 void arith::IndexCastUIOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+  patterns.add<IndexCastUIOfExtUI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index df3eea25a8d20..b7f6e525a6a74 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -310,7 +310,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
       cast.setExact(true);
       return cast;
     }
-    return arith::IndexCastUIOp::create(builder, loc, dstType, src);
+    auto cast = arith::IndexCastUIOp::create(builder, loc, dstType, src);
+    cast.setExact(true);
+    return cast;
   }
 
   auto srcInt = cast<IntegerType>(srcElemType);
@@ -448,39 +450,6 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
-/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
-/// This pattern assumes all passed `targetBitwidths` are not wider than index
-/// type.
-template <typename CastOp>
-struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
-  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
-
-  LogicalResult matchAndRewrite(CastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
-    if (!srcOp)
-      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
-
-    Value src = srcOp.getIn();
-    if (src.getType() != op.getType())
-      return rewriter.notifyMatchFailure(op, "outer types don't match");
-
-    if (!srcOp.getType().isIndex())
-      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
-
-    auto intType = dyn_cast<IntegerType>(op.getType());
-    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
-      return failure();
-
-    rewriter.replaceOp(op, src);
-    return success();
-  }
-
-private:
-  SmallVector<unsigned, 4> targetBitwidths;
-};
-
 struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
   NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
@@ -728,8 +697,6 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>>(patterns.getContext(),
-                                                        bitwidthsSupported);
 }
 
 void mlir::arith::populateControlFlowValuesNarrowingPatterns(
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 5a26d1d0e7fa4..e9c232df06b77 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -8,7 +8,7 @@
 // CHECK-LABEL: func @test_addi_neg
 //       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index
@@ -23,10 +23,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : i8 to index
 //       CHECK:  return %[[RES_CASTED]] : index
 func.func @test_addi() -> index {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -38,10 +38,10 @@ func.func @test_addi() -> index {
 // CHECK-LABEL: func @test_addi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex>
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] exact : vector<4xi8> to vector<4xindex>
 //       CHECK:  return %[[RES_CASTED]] : vector<4xindex>
 func.func @test_addi_vec() -> vector<4xindex> {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
@@ -68,8 +68,8 @@ func.func @test_addi_i64() -> i64 {
 // CHECK-LABEL: func @test_cmpi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
 //       CHECK:  return %[[RES]] : i1
 func.func @test_cmpi() -> i1 {
@@ -82,8 +82,8 @@ func.func @test_cmpi() -> i1 {
 // CHECK-LABEL: func @test_cmpi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : vector<4xindex> to vector<4xi8>
 //       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
 //       CHECK:  return %[[RES]] : vector<4xi1>
 func.func @test_cmpi_vec() -> vector<4xi1> {
@@ -97,10 +97,10 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] exact : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] exact : index to i8
 //       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
+//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] exact : index to i8
 //       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
 //       CHECK:  return %[[RES2]] : i1
 func.func @test_add_cmpi() -> i1 {
@@ -323,16 +323,16 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
 // CHECK: %[[BOUND:.+]] = test.with_bounds
 // CHECK-SAME: umax = 112
 // Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
-// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] exact : index to i16
 // CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]]  : i16 {
-// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
-// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK:   %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] exact : i16 to index
+// CHECK:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] exact : index to i8
+// CHECK:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] exact : index to i8
 // CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
 // CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
 // CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] exact : i8 to index
 // CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] exact : index to i16
-// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] exact : index to i16
 // CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
 // CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
 // CHECK:   scf.if %[[V3]]
@@ -370,7 +370,7 @@ func.func @loop_with_iter_arg() {
 //       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
 // Check iter args are still present
 //       CHECK:  scf.for {{.*}} iter_args({{.*}})
-//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] exact : index to i8
 //       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] exact : index to i8
 //       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
 //       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] exact : i8 to index



More information about the Mlir-commits mailing list