[Mlir-commits] [mlir] [mlir][arith] Add `nneg` to index_castui. (PR #183383)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Feb 25 13:49:06 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/183383

>From 2bb6bc335fc27f61f867d3ab24f7b3e4d8f4bd8f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:14 -0500
Subject: [PATCH 1/5] [mlir][arith] Add nneg to index_castui.

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c372038a8d43e..50cbd970ef6ac 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1598,15 +1598,32 @@ def Arith_IndexCastOp
 
 def Arith_IndexCastUIOp
   : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
-                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+                  DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
   let summary = "unsigned cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
     vectors. Index is an integer of platform-specific bit width. If casting to
     a wider integer, the value is zero-extended. If casting to a narrower
     integer, the value is truncated.
+
+    When the `nneg` flag is present, the operand is assumed to be non-negative.
+    In this case, zero extension is equivalent to sign extension. When this
+    assumption is violated, the result is poison.
+
+    Example:
+
+    ```mlir
+      %0 = arith.index_castui %a : i32 to index
+      %1 = arith.index_castui %a nneg : i32 to index
+    ```
   }];
 
+  let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+  let results = (outs IndexCastTypeConstraint:$out);
+  let assemblyFormat = [{
+    $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }

>From 73baf64f31be1c4f251943a7d311244f0f9b46e3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:49 -0500
Subject: [PATCH 2/5] [mlir][arith] Add roundtrip tests

---
 mlir/test/Dialect/Arith/ops.mlir | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9541cecb489f6..9765db69d6dd5 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,7 +909,6 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
   return %0 : vector<[8]xi64>
 }
 
-
 // CHECK-LABEL: test_index_castui0
 func.func @test_index_castui0(%arg0 : i32) -> index {
   %0 = arith.index_castui %arg0 : i32 to index
@@ -958,6 +957,20 @@ func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vect
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_index_castui_nneg
+// CHECK: arith.index_castui %{{.*}} nneg : i32 to index
+func.func @test_index_castui_nneg(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 nneg : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_vector
+// CHECK: arith.index_castui %{{.*}} nneg : vector<8xi32> to vector<8xindex>
+func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+  %0 = arith.index_castui %arg0 nneg : vector<8xi32> to vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
 // CHECK-LABEL: test_bitcast0
 func.func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64

>From a85a7561cdbb4cee48d73710532b960d9c188926 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:20:49 -0500
Subject: [PATCH 3/5] [mlir][arith] Update patterns and canonicalizations

---
 .../Dialect/Arith/IR/ArithCanonicalization.td    |  6 +++---
 mlir/test/Dialect/Arith/canonicalize.mlir        | 16 +++++++++++++---
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8be2af5cb3cfc..fb9c16db91431 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -304,14 +304,14 @@ def IndexCastOfExtSI :
 
 // index_castui(index_castui(x)) -> x, if dstType == srcType.
 def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
         (replaceWithValue $x),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
 
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
-    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg)),
-        (Arith_IndexCastUIOp $x)>;
+    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
+        (Arith_IndexCastUIOp $x, $nneg1)>;
 
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 52f4a54feebab..4dc29897cec26 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -597,15 +597,25 @@ func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
   return %idx : index
 }
 
-// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg
-//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_extui
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 nneg : i8 to index
 //       CHECK:   return %[[res]]
-func.func @indexCastUIOfUnsignedExtend_nneg(%arg0: i8) -> index {
+func.func @indexCastUIOfUnsignedExtend_nneg_on_extui(%arg0: i8) -> index {
   %ext = arith.extui %arg0 nneg : i8 to i16
   %idx = arith.index_castui %ext : i16 to index
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_castui
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+//   CHECK-NOT:   nneg
+//       CHECK:   return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 : i8 to i16
+  %idx = arith.index_castui %ext nneg : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @indexCastFold
 //       CHECK:   %[[res:.*]] = arith.constant -2 : index
 //       CHECK:   return %[[res]]

>From adabb2ec15f4ce4d9628b4084aae90511a798ccd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:25:37 -0500
Subject: [PATCH 4/5] [mlir][arith] Update lowering

---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 23 +++++++++++++++----
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 19 +++++++++++++++
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 178dcd419264d..e7f561e8a4d67 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -307,15 +307,23 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
     return success();
   }
 
+  bool isNonNeg = false;
+  if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+    isNonNeg = op.getNonNeg();
+
   // Handle the scalar and 1D vector cases.
   Type operandType = adaptor.getIn().getType();
   if (!isa<LLVM::LLVMArrayType>(operandType)) {
     Type targetType = this->typeConverter->convertType(resultType);
-    if (targetBits < sourceBits)
+    if (targetBits < sourceBits) {
       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                  adaptor.getIn());
-    else
-      rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
+    } else {
+      auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
+                                                          adaptor.getIn());
+      if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+        extOp.setNonNeg(isNonNeg);
+    }
     return success();
   }
 
@@ -330,8 +338,13 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
           return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                        adaptor.getIn());
         }
-        return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
-                                 adaptor.getIn());
+        auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+                                       adaptor.getIn());
+        if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
+          if (isNonNeg)
+            extOp.setNonNeg(true);
+        }
+        return extOp;
       },
       rewriter);
 }
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7ae27a884d5d3..47069906fa110 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -141,6 +141,25 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 
 // -----
 
+// CHECK-LABEL: @index_castui_nneg
+func.func @index_castui_nneg(%arg0: i1) {
+// CHECK: llvm.zext nneg %{{.*}} : i1 to i{{.*}}
+  %0 = arith.index_castui %arg0 nneg : i1 to index
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @index_castui_nneg_not_set
+func.func @index_castui_nneg_not_set(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nneg
+  %0 = arith.index_castui %arg0 : i1 to index
+  return
+}
+
+// -----
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {

>From 3bbdb48f06a1b38e0ff665c88b7dd7e9639b7462 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:48:47 -0500
Subject: [PATCH 5/5] [mlir][arith] Clarify index_castui nneg description and example

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 50cbd970ef6ac..b308ce29e6468 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1607,15 +1607,17 @@ def Arith_IndexCastUIOp
     a wider integer, the value is zero-extended. If casting to a narrower
     integer, the value is truncated.
 
-    When the `nneg` flag is present, the operand is assumed to be non-negative.
-    In this case, zero extension is equivalent to sign extension. When this
-    assumption is violated, the result is poison.
+    When the `nneg` flag is present, the operand is assumed to have its
+    most significant bit set to 0 (i.e. to be non-negative when read as a
+    signed integer). In this case, zero extension is equivalent to sign
+    extension. When this assumption is violated, the result is poison.
 
     Example:
 
     ```mlir
       %0 = arith.index_castui %a : i32 to index
       %1 = arith.index_castui %a nneg : i32 to index
+      %2 = arith.index_castui %b nneg : index to i32
     ```
   }];
 



More information about the Mlir-commits mailing list