[Mlir-commits] [mlir] [mlir][arith] Add `exact` to `index_cast{, ui}` (PR #183395)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Feb 26 06:51:44 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/183395
>From e329975132e2bc2ecf3ca49e7ce51b2813aca813 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:14 -0500
Subject: [PATCH 1/7] [mlir][arith] Add nneg to index_castui.
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 19 ++++++++++++++-
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 23 +++++++++++++++----
.../Dialect/Arith/IR/ArithCanonicalization.td | 6 ++---
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 19 +++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 16 ++++++++++---
mlir/test/Dialect/Arith/ops.mlir | 15 +++++++++++-
6 files changed, 85 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c372038a8d43e..50cbd970ef6ac 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1598,15 +1598,32 @@ def Arith_IndexCastOp
def Arith_IndexCastUIOp
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
- [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
let summary = "unsigned cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
vectors. Index is an integer of platform-specific bit width. If casting to
a wider integer, the value is zero-extended. If casting to a narrower
integer, the value is truncated.
+
+ When the `nneg` flag is present, the operand is assumed to be non-negative.
+ In this case, zero extension is equivalent to sign extension. When this
+ assumption is violated, the result is poison.
+
+ Example:
+
+ ```mlir
+ %0 = arith.index_castui %a : i32 to index
+ %1 = arith.index_castui %a nneg : i32 to index
+ ```
}];
+ let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+ let results = (outs IndexCastTypeConstraint:$out);
+ let assemblyFormat = [{
+ $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 178dcd419264d..e7f561e8a4d67 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -307,15 +307,23 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return success();
}
+ bool isNonNeg = false;
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+ isNonNeg = op.getNonNeg();
+
// Handle the scalar and 1D vector cases.
Type operandType = adaptor.getIn().getType();
if (!isa<LLVM::LLVMArrayType>(operandType)) {
Type targetType = this->typeConverter->convertType(resultType);
- if (targetBits < sourceBits)
+ if (targetBits < sourceBits) {
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
adaptor.getIn());
- else
- rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
+ } else {
+ auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
+ adaptor.getIn());
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+ extOp.setNonNeg(isNonNeg);
+ }
return success();
}
@@ -330,8 +338,13 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getIn());
}
- return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
+ if (isNonNeg)
+ extOp.setNonNeg(true);
+ }
+ return extOp;
},
rewriter);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8be2af5cb3cfc..fb9c16db91431 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -304,14 +304,14 @@ def IndexCastOfExtSI :
// index_castui(index_castui(x)) -> x, if dstType == srcType.
def IndexCastUIOfIndexCastUI :
- Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+ Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
- Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg)),
- (Arith_IndexCastUIOp $x)>;
+ Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
+ (Arith_IndexCastUIOp $x, $nneg1)>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7ae27a884d5d3..47069906fa110 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -141,6 +141,25 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// -----
+// CHECK-LABEL: @index_castui_nneg
+func.func @index_castui_nneg(%arg0: i1) {
+// CHECK: llvm.zext nneg %{{.*}} : i1 to i{{.*}}
+ %0 = arith.index_castui %arg0 nneg : i1 to index
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @index_castui_nneg_not_set
+func.func @index_castui_nneg_not_set(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nneg
+ %0 = arith.index_castui %arg0 : i1 to index
+ return
+}
+
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 52f4a54feebab..4dc29897cec26 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -597,15 +597,25 @@ func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
return %idx : index
}
-// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg
-// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_extui
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 nneg : i8 to index
// CHECK: return %[[res]]
-func.func @indexCastUIOfUnsignedExtend_nneg(%arg0: i8) -> index {
+func.func @indexCastUIOfUnsignedExtend_nneg_on_extui(%arg0: i8) -> index {
%ext = arith.extui %arg0 nneg : i8 to i16
%idx = arith.index_castui %ext : i16 to index
return %idx : index
}
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_castui
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-NOT: nneg
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
+ %ext = arith.extui %arg0 : i8 to i16
+ %idx = arith.index_castui %ext nneg : i16 to index
+ return %idx : index
+}
+
// CHECK-LABEL: @indexCastFold
// CHECK: %[[res:.*]] = arith.constant -2 : index
// CHECK: return %[[res]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9541cecb489f6..9765db69d6dd5 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,7 +909,6 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
return %0 : vector<[8]xi64>
}
-
// CHECK-LABEL: test_index_castui0
func.func @test_index_castui0(%arg0 : i32) -> index {
%0 = arith.index_castui %arg0 : i32 to index
@@ -958,6 +957,20 @@ func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vect
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_index_castui_nneg
+// CHECK: arith.index_castui %{{.*}} nneg : i32 to index
+func.func @test_index_castui_nneg(%arg0 : i32) -> index {
+ %0 = arith.index_castui %arg0 nneg : i32 to index
+ return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_vector
+// CHECK: arith.index_castui %{{.*}} nneg : vector<8xi32> to vector<8xindex>
+func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+ %0 = arith.index_castui %arg0 nneg : vector<8xi32> to vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
// CHECK-LABEL: test_bitcast0
func.func @test_bitcast0(%arg0 : i64) -> f64 {
%0 = arith.bitcast %arg0 : i64 to f64
>From d6ee09abb0dd54c00047f26e80ca2412fb275068 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:49:36 -0500
Subject: [PATCH 2/7] [mlir][arith] Add ArithExactFlagInterface.
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 3 +-
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 47 +++++++++++++++++++
2 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 50cbd970ef6ac..b11f5e8b8eb4d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -160,7 +160,8 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
- [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index d1b8e250cdb59..e8287ac2d6bcc 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -153,6 +153,53 @@ def ArithNonNegFlagInterface : OpInterface<"ArithNonNegFlagInterface"> {
];
}
+def ArithExactFlagInterface : OpInterface<"ArithExactFlagInterface"> {
+ let description = [{
+ Access to op exact flag.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the exact flag set",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "getExact",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getIsExactAttr() != nullptr;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Set the exact flag for the operation",
+ /*returnType=*/ "void",
+ /*methodName=*/ "setExact",
+ /*args=*/ (ins "bool":$isExact),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ if (isExact)
+ op.setIsExactAttr(UnitAttr::get(op->getContext()));
+ else
+ op.removeIsExactAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the exact flag attribute for
+ the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getExactFlagAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "isExact";
+ }]
+ >
+ ];
+}
+
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
let description = [{
Access to op rounding mode.
>From 8123bdcab57532df873e2b5476a19c480d884e9c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:54:38 -0500
Subject: [PATCH 3/7] [mlir][arith] Add exact attribute to index_cast{,ui}
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 34 ++++++++++++++++---
1 file changed, 30 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index b11f5e8b8eb4d..c0987fba98562 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1580,15 +1580,32 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
def Arith_IndexCastOp
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
- [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithExactFlagInterface>]> {
let summary = "cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
vectors. Index is an integer of platform-specific bit width. If casting to
a wider integer, the value is sign-extended. If casting to a narrower
integer, the value is truncated.
+
+ If the `exact` attribute is present, it is assumed that the index type width
+ is such that the conversion does not lose information. When this assumption
+ is violated, the result is poison.
+
+ Example:
+
+ ```mlir
+ %0 = arith.index_cast %a : index to i64
+ %1 = arith.index_cast %a exact : index to i64
+ ```
}];
+ let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$isExact);
+ let results = (outs IndexCastTypeConstraint:$out);
+ let assemblyFormat = [{
+ $in (`exact` $isExact^)? attr-dict `:` type($in) `to` type($out)
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1600,7 +1617,8 @@ def Arith_IndexCastOp
def Arith_IndexCastUIOp
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
- DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
+ DeclareOpInterfaceMethods<ArithNonNegFlagInterface>,
+ DeclareOpInterfaceMethods<ArithExactFlagInterface>]> {
let summary = "unsigned cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1612,18 +1630,26 @@ def Arith_IndexCastUIOp
In this case, zero extension is equivalent to sign extension. When this
assumption is violated, the result is poison.
+ If the `exact` attribute is present, it is assumed that the index type width
+ is such that the conversion does not lose information. When this assumption
+ is violated, the result is poison.
+
Example:
```mlir
%0 = arith.index_castui %a : i32 to index
%1 = arith.index_castui %a nneg : i32 to index
+ %2 = arith.index_castui %a exact : i32 to index
+ %3 = arith.index_castui %a nneg exact : i32 to index
```
}];
- let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+ let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg,
+ UnitAttr:$isExact);
let results = (outs IndexCastTypeConstraint:$out);
let assemblyFormat = [{
- $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+ $in (`nneg` $nonNeg^)? (`exact` $isExact^)? attr-dict
+ `:` type($in) `to` type($out)
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
>From d624802197961593af945618c1671cfd84ab08f2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:55:34 -0500
Subject: [PATCH 4/7] [mlir][arith] Add roundtrip tests
---
mlir/test/Dialect/Arith/ops.mlir | 28 ++++++++++++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9765db69d6dd5..b7ad2ff8a8694 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,6 +909,20 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_index_cast_exact
+// CHECK: arith.index_cast %{{.*}} exact : i32 to index
+func.func @test_index_cast_exact(%arg0 : i32) -> index {
+ %0 = arith.index_cast %arg0 exact : i32 to index
+ return %0 : index
+}
+
+// CHECK-LABEL: test_index_cast_exact_vector
+// CHECK: arith.index_cast %{{.*}} exact : vector<8xi32> to vector<8xindex>
+func.func @test_index_cast_exact_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+ %0 = arith.index_cast %arg0 exact : vector<8xi32> to vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
// CHECK-LABEL: test_index_castui0
func.func @test_index_castui0(%arg0 : i32) -> index {
%0 = arith.index_castui %arg0 : i32 to index
@@ -971,6 +985,20 @@ func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xinde
return %0 : vector<8xindex>
}
+// CHECK-LABEL: test_index_castui_exact
+// CHECK: arith.index_castui %{{.*}} exact : i32 to index
+func.func @test_index_castui_exact(%arg0 : i32) -> index {
+ %0 = arith.index_castui %arg0 exact : i32 to index
+ return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_exact
+// CHECK: arith.index_castui %{{.*}} nneg exact : i32 to index
+func.func @test_index_castui_nneg_exact(%arg0 : i32) -> index {
+ %0 = arith.index_castui %arg0 nneg exact : i32 to index
+ return %0 : index
+}
+
// CHECK-LABEL: test_bitcast0
func.func @test_bitcast0(%arg0 : i64) -> f64 {
%0 = arith.bitcast %arg0 : i64 to f64
>From 3d2aa5460679f52df7b51f8081be45c581eb9cfc Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:03:31 -0500
Subject: [PATCH 5/7] [mlir][arith] Update patterns and canonicalization tests
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 12 +++++----
mlir/test/Dialect/Arith/canonicalize.mlir | 27 +++++++++++++++++++
2 files changed, 34 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index fb9c16db91431..5eb973174a53c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -290,13 +290,14 @@ def SelectI1ToNot :
// index_cast(index_cast(x)) -> x, if dstType == srcType.
def IndexCastOfIndexCast :
- Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
+ Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_cast(extsi(x)) -> index_cast(x)
def IndexCastOfExtSI :
- Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
+ Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x), $exact),
+ (Arith_IndexCastOp $x, $exact)>;
//===----------------------------------------------------------------------===//
// IndexCastUIOp
@@ -304,14 +305,15 @@ def IndexCastOfExtSI :
// index_castui(index_castui(x)) -> x, if dstType == srcType.
def IndexCastUIOfIndexCastUI :
- Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
+ Pat<(Arith_IndexCastUIOp:$res
+ (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
- Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
- (Arith_IndexCastUIOp $x, $nneg1)>;
+ Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2, $exact),
+ (Arith_IndexCastUIOp $x, $nneg1, $exact)>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..cfbf93b7f2761 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -588,6 +588,15 @@ func.func @indexCastOfSignExtend(%arg0: i8) -> index {
return %idx : index
}
+// CHECK-LABEL: @indexCastOfSignExtend_exact
+// CHECK: %[[res:.+]] = arith.index_cast %arg0 exact : i8 to index
+// CHECK: return %[[res]]
+func.func @indexCastOfSignExtend_exact(%arg0: i8) -> index {
+ %ext = arith.extsi %arg0 : i8 to i16
+ %idx = arith.index_cast %ext exact : i16 to index
+ return %idx : index
+}
+
// CHECK-LABEL: @indexCastUIOfUnsignedExtend
// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
// CHECK: return %[[res]]
@@ -616,6 +625,24 @@ func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
return %idx : index
}
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_exact
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 exact : i8 to index
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_exact(%arg0: i8) -> index {
+ %ext = arith.extui %arg0 : i8 to i16
+ %idx = arith.index_castui %ext exact : i16 to index
+ return %idx : index
+}
+
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_exact
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 nneg exact : i8 to index
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
+ %ext = arith.extui %arg0 nneg : i8 to i16
+ %idx = arith.index_castui %ext exact : i16 to index
+ return %idx : index
+}
+
// CHECK-LABEL: @indexCastFold
// CHECK: %[[res:.*]] = arith.constant -2 : index
// CHECK: return %[[res]]
>From f10aaa235c802b6cb1930c31f0436e4e182d61d6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:21:56 -0500
Subject: [PATCH 6/7] [mlir][arith] Update lowerings for index_cast
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 36 ++++++++++---
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 52 +++++++++++++++++++
2 files changed, 80 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index e7f561e8a4d67..e0e1be35e4e1d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -311,13 +311,32 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
isNonNeg = op.getNonNeg();
+ bool isExact = op.getExact();
+
+ // Map exact to the appropriate overflow flag(s) for truncation:
+ // index_cast (signed) exact -> trunc nsw
+ // index_castui (unsigned) exact -> trunc nuw
+ // index_castui nneg exact -> trunc nuw nsw
+ LLVM::IntegerOverflowFlags truncOverflow = LLVM::IntegerOverflowFlags::none;
+ if (isExact) {
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::SExtOp>) {
+ truncOverflow = LLVM::IntegerOverflowFlags::nsw;
+ } else {
+ truncOverflow = LLVM::IntegerOverflowFlags::nuw;
+ if (isNonNeg)
+ truncOverflow |= LLVM::IntegerOverflowFlags::nsw;
+ }
+ }
+
// Handle the scalar and 1D vector cases.
Type operandType = adaptor.getIn().getType();
if (!isa<LLVM::LLVMArrayType>(operandType)) {
Type targetType = this->typeConverter->convertType(resultType);
if (targetBits < sourceBits) {
- rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
- adaptor.getIn());
+ auto truncOp = rewriter.replaceOpWithNewOp<LLVM::TruncOp>(
+ op, targetType, adaptor.getIn());
+ if (isExact)
+ truncOp.setOverflowFlags(truncOverflow);
} else {
auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
adaptor.getIn());
@@ -335,15 +354,16 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename OpTy::Adaptor adaptor(operands);
if (targetBits < sourceBits) {
- return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ auto truncOp = LLVM::TruncOp::create(rewriter, op.getLoc(),
+ llvm1DVectorTy, adaptor.getIn());
+ if (isExact)
+ truncOp.setOverflowFlags(truncOverflow);
+ return truncOp;
}
auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getIn());
- if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
- if (isNonNeg)
- extOp.setNonNeg(true);
- }
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+ extOp.setNonNeg(isNonNeg);
return extOp;
},
rewriter);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 47069906fa110..2845df23293d5 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -160,6 +160,58 @@ func.func @index_castui_nneg_not_set(%arg0: i1) {
// -----
+// index_cast exact on truncation lowers to trunc nsw (signed semantics).
+// CHECK-LABEL: @index_cast_exact_trunc
+func.func @index_cast_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nsw> : i{{.*}} to i1
+ %0 = arith.index_cast %arg0 exact : index to i1
+ return
+}
+
+// -----
+
+// index_cast exact on widening: exact is vacuously true, sext has no flag.
+// CHECK-LABEL: @index_cast_exact_ext
+func.func @index_cast_exact_ext(%arg0: i1) {
+// CHECK: llvm.sext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nsw
+ %0 = arith.index_cast %arg0 exact : i1 to index
+ return
+}
+
+// -----
+
+// index_castui exact on truncation lowers to trunc nuw (unsigned semantics).
+// CHECK-LABEL: @index_castui_exact_trunc
+func.func @index_castui_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nuw> : i{{.*}} to i1
+ %0 = arith.index_castui %arg0 exact : index to i1
+ return
+}
+
+// -----
+
+// index_castui nneg exact on truncation lowers to trunc nuw nsw.
+// CHECK-LABEL: @index_castui_nneg_exact_trunc
+func.func @index_castui_nneg_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nsw, nuw> : i{{.*}} to i1
+ %0 = arith.index_castui %arg0 nneg exact : index to i1
+ return
+}
+
+// -----
+
+// index_castui exact on widening: exact is vacuously true, zext has no flag.
+// CHECK-LABEL: @index_castui_exact_ext
+func.func @index_castui_exact_ext(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nuw
+ %0 = arith.index_castui %arg0 exact : i1 to index
+ return
+}
+
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
>From 74797e1da5f72c87e354096f61ae83bc566faf90 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 26 Feb 2026 08:59:53 -0500
Subject: [PATCH 7/7] Address comments
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +-
.../Dialect/Arith/IR/ArithCanonicalization.td | 6 ++--
mlir/test/Dialect/Arith/canonicalize.mlir | 36 ++++++++++++++++++-
mlir/test/Dialect/Arith/ops.mlir | 2 +-
4 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c0987fba98562..d081204154b1f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1648,7 +1648,7 @@ def Arith_IndexCastUIOp
UnitAttr:$isExact);
let results = (outs IndexCastTypeConstraint:$out);
let assemblyFormat = [{
- $in (`nneg` $nonNeg^)? (`exact` $isExact^)? attr-dict
+ $in oilist(`exact` $isExact | `nneg` $nonNeg) attr-dict
`:` type($in) `to` type($out)
}];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 5eb973174a53c..f16cd91a37f7c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -303,12 +303,14 @@ def IndexCastOfExtSI :
// IndexCastUIOp
//===----------------------------------------------------------------------===//
-// index_castui(index_castui(x)) -> x, if dstType == srcType.
+// index_castui(index_castui(x)) -> x, if dstType == srcType and at least one
+// exact flag is set (guaranteeing no information loss in either cast).
def IndexCastUIOfIndexCastUI :
Pat<(Arith_IndexCastUIOp:$res
(Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
(replaceWithValue $x),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+ [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
+ (Constraint<CPred<"$0 || $1">> $exact1, $exact2)]>;
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cfbf93b7f2761..d43af670d2793 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -635,7 +635,7 @@ func.func @indexCastUIOfUnsignedExtend_exact(%arg0: i8) -> index {
}
// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_exact
-// CHECK: %[[res:.+]] = arith.index_castui %arg0 nneg exact : i8 to index
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 exact nneg : i8 to index
// CHECK: return %[[res]]
func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
%ext = arith.extui %arg0 nneg : i8 to i16
@@ -643,6 +643,40 @@ func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
return %idx : index
}
+// index_castui(index_castui(x)) -> x only when at least one exact is set.
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_no_exact
+// CHECK: arith.index_castui
+// CHECK: arith.index_castui
+func.func @indexCastUIOfIndexCastUI_no_exact(%arg0: i32) -> i32 {
+ %idx = arith.index_castui %arg0 : i32 to index
+ %res = arith.index_castui %idx : index to i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_inner
+// CHECK: return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI_exact_inner(%arg0: i32) -> i32 {
+ %idx = arith.index_castui %arg0 exact : i32 to index
+ %res = arith.index_castui %idx : index to i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_outer
+// CHECK: return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI_exact_outer(%arg0: i32) -> i32 {
+ %idx = arith.index_castui %arg0 : i32 to index
+ %res = arith.index_castui %idx exact : index to i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_both
+// CHECK: return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI_exact_both(%arg0: i32) -> i32 {
+ %idx = arith.index_castui %arg0 exact : i32 to index
+ %res = arith.index_castui %idx exact : index to i32
+ return %res : i32
+}
+
// CHECK-LABEL: @indexCastFold
// CHECK: %[[res:.*]] = arith.constant -2 : index
// CHECK: return %[[res]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index b7ad2ff8a8694..a9eabe97ebfcd 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -993,7 +993,7 @@ func.func @test_index_castui_exact(%arg0 : i32) -> index {
}
// CHECK-LABEL: test_index_castui_nneg_exact
-// CHECK: arith.index_castui %{{.*}} nneg exact : i32 to index
+// CHECK: arith.index_castui %{{.*}} exact nneg : i32 to index
func.func @test_index_castui_nneg_exact(%arg0 : i32) -> index {
%0 = arith.index_castui %arg0 nneg exact : i32 to index
return %0 : index
More information about the Mlir-commits mailing list