[Mlir-commits] [mlir] 8c81064 - [MLIR][Arith] Fix index_cast/index_castui chain folding to check intermediate width (#189042)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 3 07:05:14 PDT 2026
Author: Mehdi Amini
Date: 2026-04-03T16:05:08+02:00
New Revision: 8c81064169c5c864f5a3e4b9474164e025b274b7
URL: https://github.com/llvm/llvm-project/commit/8c81064169c5c864f5a3e4b9474164e025b274b7
DIFF: https://github.com/llvm/llvm-project/commit/8c81064169c5c864f5a3e4b9474164e025b274b7.diff
LOG: [MLIR][Arith] Fix index_cast/index_castui chain folding to check intermediate width (#189042)
The patterns `IndexCastOfIndexCast` and `IndexCastUIOfIndexCastUI` in
ArithCanonicalization.td incorrectly eliminated a pair of index casts
whenever the outer result type equalled the original source type,
without verifying that the intermediate cast was lossless.
For example, the following was wrongly folded to `%arg0`:
%0 = index_castui %arg0 : i64 to index
%1 = index_castui %0 : index to i8 ← truncates to 8 bits
%2 = index_castui %1 : i8 to index ← incorrectly removed
The pattern matched `%1`/`%2` because `i8.to(index)` has the same result
type as `i64.to(index)`, even though the i8 intermediate silently drops
56 bits. The same bug existed for the signed `index_cast` variant.
Fix: move the optimization into the `fold` methods of `IndexCastOp` and
`IndexCastUIOp` with an explicit check that the intermediate type is at
least as wide as the source type (using
`IndexType::kInternalStorageBitWidth` as the representative width for
`index`). Only then is the round-trip guaranteed lossless and the chain
can be collapsed.
Fixes #90238
Fixes #90296
Assisted-by: Claude Code
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..a15e19b24e54b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -288,12 +288,6 @@ def SelectI1ToNot :
// IndexCastOp
//===----------------------------------------------------------------------===//
-// index_cast(index_cast(x)) -> x, if dstType == srcType.
-def IndexCastOfIndexCast :
- Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
- (replaceWithValue $x),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
-
// index_cast(extsi(x)) -> index_cast(x)
def IndexCastOfExtSI :
Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
@@ -302,12 +296,6 @@ def IndexCastOfExtSI :
// IndexCastUIOp
//===----------------------------------------------------------------------===//
-// index_castui(index_castui(x)) -> x, if dstType == srcType.
-def IndexCastUIOfIndexCastUI :
- Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
- (replaceWithValue $x),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
-
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..569d1869a5abe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1909,6 +1909,15 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
// IndexCastOp
//===----------------------------------------------------------------------===//
+/// Return the bit-width of \p t for the purpose of index_cast width checks.
+/// For vector types use the element type; index maps to its internal storage
+/// width (64 on all current targets).
+static unsigned getIndexCastWidth(Type t) {
+ if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
+ return intTy.getWidth();
+ return IndexType::kInternalStorageBitWidth;
+}
+
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;
@@ -1933,16 +1942,29 @@ OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
resultBitwidth = intTy.getWidth();
- return constFoldCastOp<IntegerAttr, IntegerAttr>(
- adaptor.getOperands(), getType(),
- [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
- return a.sextOrTrunc(resultBitwidth);
- });
+ if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
+ adaptor.getOperands(), getType(),
+ [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+ return a.sextOrTrunc(resultBitwidth);
+ }))
+ return foldResult;
+
+ // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least
+ // as wide as A. If B is narrower, the inner cast truncates and the outer
+ // cast sign-extends, so the round-trip is lossy.
+ if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
+ Value x = inner.getOperand();
+ if (x.getType() == getType()) {
+ if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
+ return x;
+ }
+ }
+ return {};
}
void arith::IndexCastOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+ patterns.add<IndexCastOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
@@ -1960,16 +1982,29 @@ OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
resultBitwidth = intTy.getWidth();
- return constFoldCastOp<IntegerAttr, IntegerAttr>(
- adaptor.getOperands(), getType(),
- [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
- return a.zextOrTrunc(resultBitwidth);
- });
+ if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
+ adaptor.getOperands(), getType(),
+ [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
+ return a.zextOrTrunc(resultBitwidth);
+ }))
+ return foldResult;
+
+ // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at
+ // least as wide as A. If B is narrower, the inner cast truncates and the
+ // outer cast zero-extends, so the round-trip is lossy.
+ if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
+ Value x = inner.getOperand();
+ if (x.getType() == getType()) {
+ if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
+ return x;
+ }
+ }
+ return {};
}
void arith::IndexCastUIOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
+ patterns.add<IndexCastUIOfExtUI>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 319dfc31ab637..e45adb7287ac4 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -693,8 +693,6 @@ func.func @arith_index_cast(%arg0: i32) -> i32 {
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t
%idx = arith.index_cast %arg0 : i32 to index
- // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t
- // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32
%int = arith.index_cast %idx : index to i32
// CHECK: %[[Const:.*]] = "emitc.constant"
@@ -704,6 +702,7 @@ func.func @arith_index_cast(%arg0: i32) -> i32 {
// CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
%bool = arith.index_cast %idx : index to i1
+ // CHECK: return %[[Arg0]] : i32
return %int : i32
}
@@ -715,8 +714,6 @@ func.func @arith_index_castui(%arg0: i32) -> i32 {
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t
%idx = arith.index_castui %arg0 : i32 to index
- // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32
- // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32
%int = arith.index_castui %idx : index to i32
// CHECK: %[[Const:.*]] = "emitc.constant"
@@ -726,6 +723,7 @@ func.func @arith_index_castui(%arg0: i32) -> i32 {
// CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
%bool = arith.index_castui %idx : index to i1
+ // CHECK: return %[[Arg0]] : i32
return %int : i32
}
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index bf1e8580a5b76..497574af2a2d8 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -237,12 +237,9 @@ func.func @index_cast_2d(%arg0: vector<1x2x3xi1>) {
// CHECK: %[[SEXT2:.*]] = llvm.sext %[[EXTRACT2]] : vector<3xi1> to vector<3xi{{.*}}>
// CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[SEXT2]], %[[INSERT1]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
%0 = arith.index_cast %arg0: vector<1x2x3xi1> to vector<1x2x3xindex>
- // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[INSERT2]][0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
- // CHECK: %[[TRUNC1:.*]] = llvm.trunc %[[EXTRACT3]] : vector<3xi{{.*}}> to vector<3xi1>
- // CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[TRUNC1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
- // CHECK: %[[EXTRACT4:.*]] = llvm.extractvalue %[[INSERT2]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
- // CHECK: %[[TRUNC2:.*]] = llvm.trunc %[[EXTRACT4]] : vector<3xi{{.*}}> to vector<3xi1>
- // CHECK: %[[INSERT4:.*]] = llvm.insertvalue %[[TRUNC2]], %[[INSERT3]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+ // The back-cast folds away: index_cast(index_cast(x:i1):index):i1 -> x
+ // because index (64-bit) is wider than i1, so the round-trip is lossless.
+ // CHECK-NOT: llvm.trunc
%1 = arith.index_cast %0: vector<1x2x3xindex> to vector<1x2x3xi1>
return
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..ee3e713f8481e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -724,6 +724,113 @@ func.func @indexCastUIFoldVectorIndexToInt() -> vector<3xi32> {
return %int : vector<3xi32>
}
+// CHECK-LABEL: @indexCastOfIndexCast_lossless
+// The intermediate index type (64 bits) is at least as wide as i64 (64 bits),
+// so the round-trip is lossless and the chain folds away.
+// CHECK: return %arg0
+func.func @indexCastOfIndexCast_lossless(%arg0: i64) -> i64 {
+ %0 = arith.index_cast %arg0 : i64 to index
+ %1 = arith.index_cast %0 : index to i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_lossy
+// The intermediate i8 type (8 bits) is narrower than index (64 bits), so
+// folding would drop the truncation — must be preserved.
+// CHECK: %[[a:.+]] = arith.index_cast %arg0 : index to i8
+// CHECK: %[[b:.+]] = arith.index_cast %[[a]] : i8 to index
+// CHECK: return %[[b]]
+func.func @indexCastOfIndexCast_lossy(%arg0: index) -> index {
+ %0 = arith.index_cast %arg0 : index to i8
+ %1 = arith.index_cast %0 : i8 to index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossless
+// The intermediate index type is at least as wide as i64, so the chain folds.
+// CHECK: return %arg0
+func.func @indexCastUIOfIndexCastUI_lossless(%arg0: i64) -> i64 {
+ %0 = arith.index_castui %arg0 : i64 to index
+ %1 = arith.index_castui %0 : index to i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_lossy
+// The intermediate i8 is narrower than index, so the truncation must be kept.
+// CHECK: %[[a:.+]] = arith.index_castui %arg0 : index to i8
+// CHECK: %[[b:.+]] = arith.index_castui %[[a]] : i8 to index
+// CHECK: return %[[b]]
+func.func @indexCastUIOfIndexCastUI_lossy(%arg0: index) -> index {
+ %0 = arith.index_castui %arg0 : index to i8
+ %1 = arith.index_castui %0 : i8 to index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_3way_lossy
+// Regression test for the original bug: a 3-element chain where the outermost
+// cast pair would be incorrectly folded away, dropping the i8 truncation.
+// CHECK: %[[a:.*]] = arith.index_castui %arg0 : i64 to index
+// CHECK: %[[b:.*]] = arith.index_castui %[[a]] : index to i8
+// CHECK: %[[c:.*]] = arith.index_castui %[[b]] : i8 to index
+// CHECK: return %[[c]]
+func.func @indexCastUIOfIndexCastUI_3way_lossy(%arg0: i64) -> index {
+ %0 = arith.index_castui %arg0 : i64 to index
+ %1 = arith.index_castui %0 : index to i8
+ %2 = arith.index_castui %1 : i8 to index
+ return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_3way_lossy
+// Signed 3-way chain where the outermost pair folds (i64->index is lossless
+// since 64 >= 64) but the inner i8 truncation is preserved. The net result
+// is that %2 becomes %0 directly, collapsing the last two casts.
+// CHECK: %[[a:.*]] = arith.index_cast %arg0 : i8 to index
+// CHECK: return %[[a]]
+func.func @indexCastOfIndexCast_3way_lossy(%arg0: i8) -> index {
+ %0 = arith.index_cast %arg0 : i8 to index
+ %1 = arith.index_cast %0 : index to i64
+ %2 = arith.index_cast %1 : i64 to index
+ return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_i8_roundtrip
+// i8 -> index -> i8: the intermediate index is at least as wide as i8 (64 >= 8),
+// so the round-trip is lossless and the chain folds away.
+// CHECK: return %arg0
+func.func @indexCastOfIndexCast_i8_roundtrip(%arg0: i8) -> i8 {
+ %0 = arith.index_cast %arg0 : i8 to index
+ %1 = arith.index_cast %0 : index to i8
+ return %1 : i8
+}
+
+// -----
+
+// CHECK-LABEL: @indexCastOfIndexCast_vector_lossy
+// vector<3xi128> -> vector<3xindex> -> vector<3xi128>: i128 (128 bits) is wider
+// than the 64-bit index, so the cast is lossy and must NOT fold.
+// CHECK: %[[a:.+]] = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex>
+// CHECK: %[[b:.+]] = arith.index_cast %[[a]] : vector<3xindex> to vector<3xi128>
+// CHECK: return %[[b]]
+func.func @indexCastOfIndexCast_vector_lossy(%arg0: vector<3xi128>) -> vector<3xi128> {
+ %0 = arith.index_cast %arg0 : vector<3xi128> to vector<3xindex>
+ %1 = arith.index_cast %0 : vector<3xindex> to vector<3xi128>
+ return %1 : vector<3xi128>
+}
+
+// -----
+
// CHECK-LABEL: @signExtendConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6f1a422324e08..4f0d4bb0d8f5d 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -853,8 +853,7 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
// CHECK-SAME: %[[B1:.+]]: i32
// CHECK-DAG: %[[T0:.+]] = linalg.index 0
// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
-// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
-// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[T0]]]
// CHECK: linalg.yield %[[CAST1]], %[[EXTRACT]]
// CHECK: return %[[RESULT]]#1
More information about the Mlir-commits
mailing list