[Mlir-commits] [mlir] [mlir][arith] Add `nneg` to index_castui. (PR #183383)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Feb 25 13:49:06 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/183383
>From 2bb6bc335fc27f61f867d3ab24f7b3e4d8f4bd8f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:14 -0500
Subject: [PATCH 1/5] [mlir][arith] Add nneg to index_castui.
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 19 ++++++++++++++++++-
1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c372038a8d43e..50cbd970ef6ac 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1598,15 +1598,32 @@ def Arith_IndexCastOp
def Arith_IndexCastUIOp
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
- [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
let summary = "unsigned cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
vectors. Index is an integer of platform-specific bit width. If casting to
a wider integer, the value is zero-extended. If casting to a narrower
integer, the value is truncated.
+
+ When the `nneg` flag is present, the operand is assumed to be non-negative.
+ In this case, zero extension is equivalent to sign extension. When this
+ assumption is violated, the result is poison.
+
+ Example:
+
+ ```mlir
+ %0 = arith.index_castui %a : i32 to index
+ %1 = arith.index_castui %a nneg : i32 to index
+ ```
}];
+ let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+ let results = (outs IndexCastTypeConstraint:$out);
+ let assemblyFormat = [{
+ $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
>From 73baf64f31be1c4f251943a7d311244f0f9b46e3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:49 -0500
Subject: [PATCH 2/5] [mlir][arith] Add roundtrip tests for index_castui nneg
---
mlir/test/Dialect/Arith/ops.mlir | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9541cecb489f6..9765db69d6dd5 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,7 +909,6 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
return %0 : vector<[8]xi64>
}
-
// CHECK-LABEL: test_index_castui0
func.func @test_index_castui0(%arg0 : i32) -> index {
%0 = arith.index_castui %arg0 : i32 to index
@@ -958,6 +957,20 @@ func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vect
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_index_castui_nneg
+// CHECK: arith.index_castui %{{.*}} nneg : i32 to index
+func.func @test_index_castui_nneg(%arg0 : i32) -> index {
+ %0 = arith.index_castui %arg0 nneg : i32 to index
+ return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_vector
+// CHECK: arith.index_castui %{{.*}} nneg : vector<8xi32> to vector<8xindex>
+func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+ %0 = arith.index_castui %arg0 nneg : vector<8xi32> to vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
// CHECK-LABEL: test_bitcast0
func.func @test_bitcast0(%arg0 : i64) -> f64 {
%0 = arith.bitcast %arg0 : i64 to f64
>From a85a7561cdbb4cee48d73710532b960d9c188926 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:20:49 -0500
Subject: [PATCH 3/5] [mlir][arith] Handle index_castui nneg in canonicalization patterns
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 6 +++---
mlir/test/Dialect/Arith/canonicalize.mlir | 16 +++++++++++++---
2 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8be2af5cb3cfc..fb9c16db91431 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -304,14 +304,14 @@ def IndexCastOfExtSI :
// index_castui(index_castui(x)) -> x, if dstType == srcType.
def IndexCastUIOfIndexCastUI :
- Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+ Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_castui(extui(x)) -> index_castui(x)
def IndexCastUIOfExtUI :
- Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg)),
- (Arith_IndexCastUIOp $x)>;
+ Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
+ (Arith_IndexCastUIOp $x, $nneg1)>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 52f4a54feebab..4dc29897cec26 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -597,15 +597,25 @@ func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
return %idx : index
}
-// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg
-// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_extui
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 nneg : i8 to index
// CHECK: return %[[res]]
-func.func @indexCastUIOfUnsignedExtend_nneg(%arg0: i8) -> index {
+func.func @indexCastUIOfUnsignedExtend_nneg_on_extui(%arg0: i8) -> index {
%ext = arith.extui %arg0 nneg : i8 to i16
%idx = arith.index_castui %ext : i16 to index
return %idx : index
}
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_castui
+// CHECK: %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-NOT: nneg
+// CHECK: return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
+ %ext = arith.extui %arg0 : i8 to i16
+ %idx = arith.index_castui %ext nneg : i16 to index
+ return %idx : index
+}
+
// CHECK-LABEL: @indexCastFold
// CHECK: %[[res:.*]] = arith.constant -2 : index
// CHECK: return %[[res]]
>From adabb2ec15f4ce4d9628b4084aae90511a798ccd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:25:37 -0500
Subject: [PATCH 4/5] [mlir][arith] Propagate index_castui nneg to llvm.zext in ArithToLLVM
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 23 +++++++++++++++----
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 19 +++++++++++++++
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 178dcd419264d..e7f561e8a4d67 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -307,15 +307,23 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return success();
}
+ bool isNonNeg = false;
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+ isNonNeg = op.getNonNeg();
+
// Handle the scalar and 1D vector cases.
Type operandType = adaptor.getIn().getType();
if (!isa<LLVM::LLVMArrayType>(operandType)) {
Type targetType = this->typeConverter->convertType(resultType);
- if (targetBits < sourceBits)
+ if (targetBits < sourceBits) {
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
adaptor.getIn());
- else
- rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
+ } else {
+ auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
+ adaptor.getIn());
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+ extOp.setNonNeg(isNonNeg);
+ }
return success();
}
@@ -330,8 +338,13 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getIn());
}
- return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
+ if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
+ if (isNonNeg)
+ extOp.setNonNeg(true);
+ }
+ return extOp;
},
rewriter);
}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7ae27a884d5d3..47069906fa110 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -141,6 +141,25 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// -----
+// CHECK-LABEL: @index_castui_nneg
+func.func @index_castui_nneg(%arg0: i1) {
+// CHECK: llvm.zext nneg %{{.*}} : i1 to i{{.*}}
+ %0 = arith.index_castui %arg0 nneg : i1 to index
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @index_castui_nneg_not_set
+func.func @index_castui_nneg_not_set(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nneg
+ %0 = arith.index_castui %arg0 : i1 to index
+ return
+}
+
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
>From 3bbdb48f06a1b38e0ff665c88b7dd7e9639b7462 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:48:47 -0500
Subject: [PATCH 5/5] [mlir][arith] Clarify index_castui nneg description and add truncation example
---
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 50cbd970ef6ac..b308ce29e6468 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1607,15 +1607,17 @@ def Arith_IndexCastUIOp
a wider integer, the value is zero-extended. If casting to a narrower
integer, the value is truncated.
- When the `nneg` flag is present, the operand is assumed to be non-negative.
- In this case, zero extension is equivalent to sign extension. When this
- assumption is violated, the result is poison.
+ When the `nneg` flag is present, the operand is assumed to have
+ the most significant bit set to 0. In this case, zero extension
+ is equivalent to sign extension. When this assumption is violated,
+ the result is poison.
Example:
```mlir
%0 = arith.index_castui %a : i32 to index
%1 = arith.index_castui %a nneg : i32 to index
+ %2 = arith.index_castui %b nneg : index to i32
```
}];
More information about the Mlir-commits mailing list