[Mlir-commits] [mlir] [MLIR][Arith] Fix index_cast/index_castui chain folding to check intermediate width (PR #189042)

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 27 09:15:47 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189042

The patterns `IndexCastOfIndexCast` and `IndexCastUIOfIndexCastUI` in ArithCanonicalization.td incorrectly eliminated a pair of index casts whenever the outer result type equalled the original source type, without verifying that the intermediate cast was lossless.

For example, in the following chain `%2` was wrongly folded to `%0` (the untruncated `%arg0`):
  %0 = index_castui %arg0 : i64 to index
  %1 = index_castui %0    : index to i8    ← truncates to 8 bits
  %2 = index_castui %1    : i8 to index    ← incorrectly removed

The pattern matched the pair `%1`/`%2` because the result type of `%2` (`index`) equals the type of `%0`, the operand of `%1`, even though the i8 intermediate silently drops 56 bits. The same bug existed for the signed `index_cast` variant.

Fix: move the optimization into the `fold` methods of `IndexCastOp` and `IndexCastUIOp` with an explicit check that the intermediate type is at least as wide as the source type (using `IndexType::kInternalStorageBitWidth` as the representative width for `index`). Only then is the round-trip guaranteed lossless and the chain can be collapsed.

Fixes #90238

Assisted-by: Claude Code

>From 22fdc9b277fee644d25bde2ac3cb1f9cf5bf71c8 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 27 Mar 2026 06:23:42 -0700
Subject: [PATCH] [MLIR][Arith] Fix index_cast/index_castui chain folding to
 check intermediate width
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The patterns `IndexCastOfIndexCast` and `IndexCastUIOfIndexCastUI` in
ArithCanonicalization.td incorrectly eliminated a pair of index casts
whenever the outer result type equalled the original source type, without
verifying that the intermediate cast was lossless.

For example, in the following chain `%2` was wrongly folded to `%0`:
  %0 = index_castui %arg0 : i64 to index
  %1 = index_castui %0    : index to i8    ← truncates to 8 bits
  %2 = index_castui %1    : i8 to index    ← incorrectly removed

The pattern matched the pair `%1`/`%2` because the result type of `%2`
(`index`) equals the type of `%0`, the operand of `%1`, even though the
i8 intermediate silently drops 56 bits. The same bug existed for the
signed `index_cast` variant.

Fix: move the optimization into the `fold` methods of `IndexCastOp` and
`IndexCastUIOp` with an explicit check that the intermediate type is at
least as wide as the source type (using `IndexType::kInternalStorageBitWidth`
as the representative width for `index`). Only then is the round-trip
guaranteed lossless and the chain can be collapsed.

Fixes #90238

Assisted-by: Claude Code
---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  12 --
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  59 ++++++++--
 mlir/test/Dialect/Arith/canonicalize.mlir     | 107 ++++++++++++++++++
 3 files changed, 154 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..a15e19b24e54b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -288,12 +288,6 @@ def SelectI1ToNot :
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-// index_cast(index_cast(x)) -> x, if dstType == srcType.
-def IndexCastOfIndexCast :
-    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
-
 // index_cast(extsi(x)) -> index_cast(x)
 def IndexCastOfExtSI :
     Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
@@ -302,12 +296,6 @@ def IndexCastOfExtSI :
 // IndexCastUIOp
 //===----------------------------------------------------------------------===//
 
-// index_castui(index_castui(x)) -> x, if dstType == srcType.
-def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
-        (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
-
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
     Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..569d1869a5abe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1909,6 +1909,15 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
+/// Return the bit-width of \p t for index_cast width checks (element type for
+/// shaped types). `index` counts as IndexType::kInternalStorageBitWidth (64);
+/// this is optimistic for targets that lower index to fewer bits (e.g. 32).
+static unsigned getIndexCastWidth(Type t) {
+  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
+    return intTy.getWidth();
+  return IndexType::kInternalStorageBitWidth;
+}
+
 static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!areValidCastInputsAndOutputs(inputs, outputs))
     return false;
@@ -1933,16 +1942,29 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
     resultBitwidth = intTy.getWidth();
 
-  return constFoldCastOp<IntegerAttr, IntegerAttr>(
-      adaptor.getOperands(), getType(),
-      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
-        return a.sextOrTrunc(resultBitwidth);
-      });
+  if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
+          adaptor.getOperands(), getType(),
+          [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+            return a.sextOrTrunc(resultBitwidth);
+          }))
+    return foldResult;
+
+  // index_cast(index_cast(x : A) : B) : A -> x when B is at least as wide as
+  // A: sign-extending x into B and truncating back to A is the identity. If B
+  // is narrower, the inner cast truncates first, so the round-trip is lossy.
+  if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
+    Value x = inner.getOperand();
+    if (x.getType() == getType()) {
+      if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
+        return x;
+    }
+  }
+  return {};
 }
 
 void arith::IndexCastOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+  patterns.add<IndexCastOfExtSI>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1960,16 +1982,29 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
   if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
     resultBitwidth = intTy.getWidth();
 
-  return constFoldCastOp<IntegerAttr, IntegerAttr>(
-      adaptor.getOperands(), getType(),
-      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
-        return a.zextOrTrunc(resultBitwidth);
-      });
+  if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
+          adaptor.getOperands(), getType(),
+          [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+            return a.zextOrTrunc(resultBitwidth);
+          }))
+    return foldResult;
+
+  // index_castui(index_castui(x : A) : B) : A -> x when B is at least as wide
+  // as A: zero-extending x into B and truncating back to A is the identity. If
+  // B is narrower, the inner cast truncates first, so the round-trip is lossy.
+  if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
+    Value x = inner.getOperand();
+    if (x.getType() == getType()) {
+      if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
+        return x;
+    }
+  }
+  return {};
 }
 
 void arith::IndexCastUIOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+  patterns.add<IndexCastUIOfExtUI>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..ee3e713f8481e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -724,6 +724,113 @@ func.func @indexCastUIFoldVectorIndexToInt() -> vector<3xi32> {
   return %int : vector<3xi32>
 }
 
+// CHECK-LABEL: @indexCastOfIndexCast_lossless
+// The intermediate index type (64 bits) is at least as wide as i64 (64 bits),
+// so the round-trip is lossless and the chain folds away.
+//       CHECK:   return %arg0
+func.func @indexCastOfIndexCast_lossless(%arg0: i64) -> i64 {
+  %0 = arith.index_cast %arg0 : i64 to index
+  %1 = arith.index_cast %0 : index to i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_lossy
+// The intermediate i8 type (8 bits) is narrower than index (64 bits), so
+// folding would drop the truncation — must be preserved.
+//       CHECK:   %[[a:.+]] = arith.index_cast %arg0 : index to i8
+//       CHECK:   %[[b:.+]] = arith.index_cast %[[a]] : i8 to index
+//       CHECK:   return %[[b]]
+func.func @indexCastOfIndexCast_lossy(%arg0: index) -> index {
+  %0 = arith.index_cast %arg0 : index to i8
+  %1 = arith.index_cast %0 : i8 to index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossless
+// The intermediate index type is at least as wide as i64, so the chain folds.
+//       CHECK:   return %arg0
+func.func @indexCastUIOfIndexCastUI_lossless(%arg0: i64) -> i64 {
+  %0 = arith.index_castui %arg0 : i64 to index
+  %1 = arith.index_castui %0 : index to i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossy
+// The intermediate i8 is narrower than index, so the truncation must be kept.
+//       CHECK:   %[[a:.+]] = arith.index_castui %arg0 : index to i8
+//       CHECK:   %[[b:.+]] = arith.index_castui %[[a]] : i8 to index
+//       CHECK:   return %[[b]]
+func.func @indexCastUIOfIndexCastUI_lossy(%arg0: index) -> index {
+  %0 = arith.index_castui %arg0 : index to i8
+  %1 = arith.index_castui %0 : i8 to index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_3way_lossy
+// Regression test for the original bug: a 3-element chain where the outermost
+// cast pair would be incorrectly folded away, dropping the i8 truncation.
+//       CHECK:   %[[a:.*]] = arith.index_castui %arg0 : i64 to index
+//       CHECK:   %[[b:.*]] = arith.index_castui %[[a]] : index to i8
+//       CHECK:   %[[c:.*]] = arith.index_castui %[[b]] : i8 to index
+//       CHECK:   return %[[c]]
+func.func @indexCastUIOfIndexCastUI_3way_lossy(%arg0: i64) -> index {
+  %0 = arith.index_castui %arg0 : i64 to index
+  %1 = arith.index_castui %0 : index to i8
+  %2 = arith.index_castui %1 : i8 to index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_3way_lossy
+// Signed 3-way chain with no lossy step: the outer pair (index -> i64 -> index)
+// folds because the i64 intermediate is as wide as index (64 >= 64), while the
+// leading i8 -> index sign-extension is kept. Net result: %2 is replaced by %0.
+//       CHECK:   %[[a:.*]] = arith.index_cast %arg0 : i8 to index
+//       CHECK:   return %[[a]]
+func.func @indexCastOfIndexCast_3way_lossy(%arg0: i8) -> index {
+  %0 = arith.index_cast %arg0 : i8 to index
+  %1 = arith.index_cast %0 : index to i64
+  %2 = arith.index_cast %1 : i64 to index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_i8_roundtrip
+// i8 -> index -> i8: the intermediate index is at least as wide as i8 (64 >= 8),
+// so the round-trip is lossless and the chain folds away.
+//       CHECK:   return %arg0
+func.func @indexCastOfIndexCast_i8_roundtrip(%arg0: i8) -> i8 {
+  %0 = arith.index_cast %arg0 : i8 to index
+  %1 = arith.index_cast %0 : index to i8
+  return %1 : i8
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_vector_lossy
+// vector<3xi128> -> vector<3xindex> -> vector<3xi128>: i128 (128 bits) is wider
+// than the 64-bit index, so the cast is lossy and must NOT fold.
+//       CHECK:   %[[a:.+]] = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex>
+//       CHECK:   %[[b:.+]] = arith.index_cast %[[a]] : vector<3xindex> to vector<3xi128>
+//       CHECK:   return %[[b]]
+func.func @indexCastOfIndexCast_vector_lossy(%arg0: vector<3xi128>) -> vector<3xi128> {
+  %0 = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex>
+  %1 = arith.index_cast %0 : vector<3xindex> to vector<3xi128>
+  return %1 : vector<3xi128>
+}
+
+// -----
+
 // CHECK-LABEL: @signExtendConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]



More information about the Mlir-commits mailing list