[Mlir-commits] [mlir] [mlir][arith] Add `exact` to `index_cast{, ui}` (PR #183395)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Feb 26 06:51:44 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/183395

>From e329975132e2bc2ecf3ca49e7ce51b2813aca813 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:12:14 -0500
Subject: [PATCH 1/7] [mlir][arith] Add nneg to index_castui.

Add an optional `nneg` unit attribute (`nonNeg`) to `arith.index_castui`.
When present, the operand is assumed to be non-negative, so zero extension
is equivalent to sign extension; if the assumption is violated the result
is poison.

- `IndexCastUIOp` now implements `ArithNonNegFlagInterface`, and its
  assembly format accepts an optional `nneg` keyword.
- ArithToLLVM: a widening `index_castui nneg` lowers to `llvm.zext nneg`.
- Canonicalization: the DRR patterns bind the new attribute; when folding
  index_castui(extui(x)) the `nneg` flag of the `extui` is kept and the
  flag of the outer `index_castui` is dropped.
- Add roundtrip, canonicalization and lowering tests.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 19 ++++++++++++++-
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 23 +++++++++++++++----
 .../Dialect/Arith/IR/ArithCanonicalization.td |  6 ++---
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 19 +++++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 16 ++++++++++---
 mlir/test/Dialect/Arith/ops.mlir              | 15 +++++++++++-
 6 files changed, 85 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c372038a8d43e..50cbd970ef6ac 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1598,15 +1598,32 @@ def Arith_IndexCastOp
 
 def Arith_IndexCastUIOp
   : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
-                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+                  DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
   let summary = "unsigned cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
     vectors. Index is an integer of platform-specific bit width. If casting to
     a wider integer, the value is zero-extended. If casting to a narrower
     integer, the value is truncated.
+
+    When the `nneg` flag is present, the operand is assumed to be non-negative.
+    In this case, zero extension is equivalent to sign extension. When this
+    assumption is violated, the result is poison.
+
+    Example:
+
+    ```mlir
+      %0 = arith.index_castui %a : i32 to index
+      %1 = arith.index_castui %a nneg : i32 to index
+    ```
   }];
 
+  let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+  let results = (outs IndexCastTypeConstraint:$out);
+  let assemblyFormat = [{
+    $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 178dcd419264d..e7f561e8a4d67 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -307,15 +307,23 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
     return success();
   }
 
+  bool isNonNeg = false;
+  if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+    isNonNeg = op.getNonNeg();
+
   // Handle the scalar and 1D vector cases.
   Type operandType = adaptor.getIn().getType();
   if (!isa<LLVM::LLVMArrayType>(operandType)) {
     Type targetType = this->typeConverter->convertType(resultType);
-    if (targetBits < sourceBits)
+    if (targetBits < sourceBits) {
       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                  adaptor.getIn());
-    else
-      rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
+    } else {
+      auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
+                                                          adaptor.getIn());
+      if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+        extOp.setNonNeg(isNonNeg);
+    }
     return success();
   }
 
@@ -330,8 +338,13 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
           return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                        adaptor.getIn());
         }
-        return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
-                                 adaptor.getIn());
+        auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+                                       adaptor.getIn());
+        if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
+          if (isNonNeg)
+            extOp.setNonNeg(true);
+        }
+        return extOp;
       },
       rewriter);
 }
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8be2af5cb3cfc..fb9c16db91431 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -304,14 +304,14 @@ def IndexCastOfExtSI :
 
 // index_castui(index_castui(x)) -> x, if dstType == srcType.
 def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x)),
+    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
         (replaceWithValue $x),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
 
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
-    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg)),
-        (Arith_IndexCastUIOp $x)>;
+    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
+        (Arith_IndexCastUIOp $x, $nneg1)>;
 
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7ae27a884d5d3..47069906fa110 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -141,6 +141,25 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 
 // -----
 
+// CHECK-LABEL: @index_castui_nneg
+func.func @index_castui_nneg(%arg0: i1) {
+// CHECK: llvm.zext nneg %{{.*}} : i1 to i{{.*}}
+  %0 = arith.index_castui %arg0 nneg : i1 to index
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @index_castui_nneg_not_set
+func.func @index_castui_nneg_not_set(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nneg
+  %0 = arith.index_castui %arg0 : i1 to index
+  return
+}
+
+// -----
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 52f4a54feebab..4dc29897cec26 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -597,15 +597,25 @@ func.func @indexCastUIOfUnsignedExtend(%arg0: i8) -> index {
   return %idx : index
 }
 
-// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg
-//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_extui
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 nneg : i8 to index
 //       CHECK:   return %[[res]]
-func.func @indexCastUIOfUnsignedExtend_nneg(%arg0: i8) -> index {
+func.func @indexCastUIOfUnsignedExtend_nneg_on_extui(%arg0: i8) -> index {
   %ext = arith.extui %arg0 nneg : i8 to i16
   %idx = arith.index_castui %ext : i16 to index
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_on_castui
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
+//   CHECK-NOT:   nneg
+//       CHECK:   return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 : i8 to i16
+  %idx = arith.index_castui %ext nneg : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @indexCastFold
 //       CHECK:   %[[res:.*]] = arith.constant -2 : index
 //       CHECK:   return %[[res]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9541cecb489f6..9765db69d6dd5 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,7 +909,6 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
   return %0 : vector<[8]xi64>
 }
 
-
 // CHECK-LABEL: test_index_castui0
 func.func @test_index_castui0(%arg0 : i32) -> index {
   %0 = arith.index_castui %arg0 : i32 to index
@@ -958,6 +957,20 @@ func.func @test_index_castui_scalable_vector1(%arg0 : vector<[8]xindex>) -> vect
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_index_castui_nneg
+// CHECK: arith.index_castui %{{.*}} nneg : i32 to index
+func.func @test_index_castui_nneg(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 nneg : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_vector
+// CHECK: arith.index_castui %{{.*}} nneg : vector<8xi32> to vector<8xindex>
+func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+  %0 = arith.index_castui %arg0 nneg : vector<8xi32> to vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
 // CHECK-LABEL: test_bitcast0
 func.func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64

>From d6ee09abb0dd54c00047f26e80ca2412fb275068 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:49:36 -0500
Subject: [PATCH 2/7] [mlir][arith] Add ArithExactFlagInterface.

Add `ArithExactFlagInterface`, which exposes the `isExact` unit attribute
through `getExact()`, `setExact(bool)` and the static
`getExactFlagAttrName()`, and attach it to the ops built from
`Arith_IntBinaryOpWithExactFlag`.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  3 +-
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    | 47 +++++++++++++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 50cbd970ef6ac..b11f5e8b8eb4d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -160,7 +160,8 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
 
 class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
-      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+       DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
     Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
                SignlessIntegerOrIndexLike:$rhs,
                UnitAttr:$isExact)>,
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index d1b8e250cdb59..e8287ac2d6bcc 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -153,6 +153,53 @@ def ArithNonNegFlagInterface : OpInterface<"ArithNonNegFlagInterface"> {
   ];
 }
 
+def ArithExactFlagInterface : OpInterface<"ArithExactFlagInterface"> {
+  let description = [{
+    Access to op exact flag.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the exact flag set",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "getExact",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getIsExactAttr() != nullptr;
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Set the exact flag for the operation",
+      /*returnType=*/  "void",
+      /*methodName=*/  "setExact",
+      /*args=*/        (ins "bool":$isExact),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        if (isExact)
+          op.setIsExactAttr(UnitAttr::get(op->getContext()));
+        else
+          op.removeIsExactAttr();
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the exact flag attribute for
+                         the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getExactFlagAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "isExact";
+      }]
+      >
+  ];
+}
+
 def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
   let description = [{
     Access to op rounding mode.

>From 8123bdcab57532df873e2b5476a19c480d884e9c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:54:38 -0500
Subject: [PATCH 3/7] [mlir][arith] Add exact attribute to index_cast{,ui}

Add an optional `exact` unit attribute (`isExact`) to `arith.index_cast`
and `arith.index_castui`, and make both ops implement
`ArithExactFlagInterface`. When `exact` is present, the conversion is
assumed not to lose information; if the assumption is violated the result
is poison. On `index_castui` the keyword follows `nneg`, e.g.
`arith.index_castui %a nneg exact : i32 to index`.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 34 ++++++++++++++++---
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index b11f5e8b8eb4d..c0987fba98562 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1580,15 +1580,32 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
 
 def Arith_IndexCastOp
   : Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
-                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+                  DeclareOpInterfaceMethods<ArithExactFlagInterface>]> {
   let summary = "cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
     vectors. Index is an integer of platform-specific bit width. If casting to
     a wider integer, the value is sign-extended. If casting to a narrower
     integer, the value is truncated.
+
+    If the `exact` attribute is present, it is assumed that the index type width
+    is such that the conversion does not lose information. When this assumption
+    is violated, the result is poison.
+
+    Example:
+
+    ```mlir
+      %0 = arith.index_cast %a : index to i64
+      %1 = arith.index_cast %a exact : index to i64
+    ```
   }];
 
+  let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$isExact);
+  let results = (outs IndexCastTypeConstraint:$out);
+  let assemblyFormat = [{
+    $in (`exact` $isExact^)? attr-dict `:` type($in) `to` type($out)
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
@@ -1600,7 +1617,8 @@ def Arith_IndexCastOp
 def Arith_IndexCastUIOp
   : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
                  [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-                  DeclareOpInterfaceMethods<ArithNonNegFlagInterface>]> {
+                  DeclareOpInterfaceMethods<ArithNonNegFlagInterface>,
+                  DeclareOpInterfaceMethods<ArithExactFlagInterface>]> {
   let summary = "unsigned cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1612,18 +1630,26 @@ def Arith_IndexCastUIOp
     In this case, zero extension is equivalent to sign extension. When this
     assumption is violated, the result is poison.
 
+    If the `exact` attribute is present, it is assumed that the index type width
+    is such that the conversion does not lose information. When this assumption
+    is violated, the result is poison.
+
     Example:
 
     ```mlir
       %0 = arith.index_castui %a : i32 to index
       %1 = arith.index_castui %a nneg : i32 to index
+      %2 = arith.index_castui %a exact : i32 to index
+      %3 = arith.index_castui %a nneg exact : i32 to index
     ```
   }];
 
-  let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg);
+  let arguments = (ins IndexCastTypeConstraint:$in, UnitAttr:$nonNeg,
+                       UnitAttr:$isExact);
   let results = (outs IndexCastTypeConstraint:$out);
   let assemblyFormat = [{
-    $in (`nneg` $nonNeg^)? attr-dict `:` type($in) `to` type($out)
+    $in (`nneg` $nonNeg^)? (`exact` $isExact^)? attr-dict
+        `:` type($in) `to` type($out)
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;

>From d624802197961593af945618c1671cfd84ab08f2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 15:55:34 -0500
Subject: [PATCH 4/7] [mlir][arith] Add roundtrip tests

Add roundtrip tests for `exact` on `arith.index_cast` (scalar and vector)
and on `arith.index_castui` (alone and together with `nneg`).
---
 mlir/test/Dialect/Arith/ops.mlir | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9765db69d6dd5..b7ad2ff8a8694 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -909,6 +909,20 @@ func.func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_index_cast_exact
+// CHECK: arith.index_cast %{{.*}} exact : i32 to index
+func.func @test_index_cast_exact(%arg0 : i32) -> index {
+  %0 = arith.index_cast %arg0 exact : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_cast_exact_vector
+// CHECK: arith.index_cast %{{.*}} exact : vector<8xi32> to vector<8xindex>
+func.func @test_index_cast_exact_vector(%arg0 : vector<8xi32>) -> vector<8xindex> {
+  %0 = arith.index_cast %arg0 exact : vector<8xi32> to vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
 // CHECK-LABEL: test_index_castui0
 func.func @test_index_castui0(%arg0 : i32) -> index {
   %0 = arith.index_castui %arg0 : i32 to index
@@ -971,6 +985,20 @@ func.func @test_index_castui_nneg_vector(%arg0 : vector<8xi32>) -> vector<8xinde
   return %0 : vector<8xindex>
 }
 
+// CHECK-LABEL: test_index_castui_exact
+// CHECK: arith.index_castui %{{.*}} exact : i32 to index
+func.func @test_index_castui_exact(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 exact : i32 to index
+  return %0 : index
+}
+
+// CHECK-LABEL: test_index_castui_nneg_exact
+// CHECK: arith.index_castui %{{.*}} nneg exact : i32 to index
+func.func @test_index_castui_nneg_exact(%arg0 : i32) -> index {
+  %0 = arith.index_castui %arg0 nneg exact : i32 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_bitcast0
 func.func @test_bitcast0(%arg0 : i64) -> f64 {
   %0 = arith.bitcast %arg0 : i64 to f64

>From 3d2aa5460679f52df7b51f8081be45c581eb9cfc Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:03:31 -0500
Subject: [PATCH 5/7] [mlir][arith] Update patterns and canonicalization tests

Update the index_cast/index_castui DRR patterns for the new `isExact`
attribute. When folding index_cast(extsi(x)) and index_castui(extui(x)),
the `exact` flag of the outer cast is carried over to the new cast. Add
canonicalization tests covering both folds with `exact`.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 12 +++++----
 mlir/test/Dialect/Arith/canonicalize.mlir     | 27 +++++++++++++++++++
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index fb9c16db91431..5eb973174a53c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -290,13 +290,14 @@ def SelectI1ToNot :
 
 // index_cast(index_cast(x)) -> x, if dstType == srcType.
 def IndexCastOfIndexCast :
-    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
+    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
         (replaceWithValue $x),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
 
 // index_cast(extsi(x)) -> index_cast(x)
 def IndexCastOfExtSI :
-    Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
+    Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x), $exact),
+        (Arith_IndexCastOp $x, $exact)>;
 
 //===----------------------------------------------------------------------===//
 // IndexCastUIOp
@@ -304,14 +305,15 @@ def IndexCastOfExtSI :
 
 // index_castui(index_castui(x)) -> x, if dstType == srcType.
 def IndexCastUIOfIndexCastUI :
-    Pat<(Arith_IndexCastUIOp:$res (Arith_IndexCastUIOp $x, $nneg1), $nneg2),
+    Pat<(Arith_IndexCastUIOp:$res
+          (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
         (replaceWithValue $x),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
 
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
-    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2),
-        (Arith_IndexCastUIOp $x, $nneg1)>;
+    Pat<(Arith_IndexCastUIOp (Arith_ExtUIOp $x, $nneg1), $nneg2, $exact),
+        (Arith_IndexCastUIOp $x, $nneg1, $exact)>;
 
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4dc29897cec26..cfbf93b7f2761 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -588,6 +588,15 @@ func.func @indexCastOfSignExtend(%arg0: i8) -> index {
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastOfSignExtend_exact
+//       CHECK:   %[[res:.+]] = arith.index_cast %arg0 exact : i8 to index
+//       CHECK:   return %[[res]]
+func.func @indexCastOfSignExtend_exact(%arg0: i8) -> index {
+  %ext = arith.extsi %arg0 : i8 to i16
+  %idx = arith.index_cast %ext exact : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @indexCastUIOfUnsignedExtend
 //       CHECK:   %[[res:.+]] = arith.index_castui %arg0 : i8 to index
 //       CHECK:   return %[[res]]
@@ -616,6 +625,24 @@ func.func @indexCastUIOfUnsignedExtend_nneg_on_castui(%arg0: i8) -> index {
   return %idx : index
 }
 
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_exact
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 exact : i8 to index
+//       CHECK:   return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_exact(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 : i8 to i16
+  %idx = arith.index_castui %ext exact : i16 to index
+  return %idx : index
+}
+
+// CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_exact
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 nneg exact : i8 to index
+//       CHECK:   return %[[res]]
+func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
+  %ext = arith.extui %arg0 nneg : i8 to i16
+  %idx = arith.index_castui %ext exact : i16 to index
+  return %idx : index
+}
+
 // CHECK-LABEL: @indexCastFold
 //       CHECK:   %[[res:.*]] = arith.constant -2 : index
 //       CHECK:   return %[[res]]

>From f10aaa235c802b6cb1930c31f0436e4e182d61d6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 25 Feb 2026 16:21:56 -0500
Subject: [PATCH 6/7] [mlir][arith] Update lowerings for index_cast

Lower the `exact` flag on truncating casts to LLVM `trunc` overflow flags:

  index_cast exact        -> trunc nsw
  index_castui exact      -> trunc nuw
  index_castui nneg exact -> trunc nuw nsw

On widening casts `exact` adds no flag to `sext`/`zext`; `nneg` on `zext`
is still set from the op. Both the scalar/1-D vector path and the
LLVM-array path are updated, and lowering tests are added.
---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 36 ++++++++++---
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 52 +++++++++++++++++++
 2 files changed, 80 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index e7f561e8a4d67..e0e1be35e4e1d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -311,13 +311,32 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
   if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
     isNonNeg = op.getNonNeg();
 
+  bool isExact = op.getExact();
+
+  // Map exact to the appropriate overflow flag(s) for truncation:
+  //   index_cast (signed) exact    -> trunc nsw
+  //   index_castui (unsigned) exact -> trunc nuw
+  //   index_castui nneg exact       -> trunc nuw nsw
+  LLVM::IntegerOverflowFlags truncOverflow = LLVM::IntegerOverflowFlags::none;
+  if (isExact) {
+    if constexpr (std::is_same_v<ExtCastTy, LLVM::SExtOp>) {
+      truncOverflow = LLVM::IntegerOverflowFlags::nsw;
+    } else {
+      truncOverflow = LLVM::IntegerOverflowFlags::nuw;
+      if (isNonNeg)
+        truncOverflow |= LLVM::IntegerOverflowFlags::nsw;
+    }
+  }
+
   // Handle the scalar and 1D vector cases.
   Type operandType = adaptor.getIn().getType();
   if (!isa<LLVM::LLVMArrayType>(operandType)) {
     Type targetType = this->typeConverter->convertType(resultType);
     if (targetBits < sourceBits) {
-      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
-                                                 adaptor.getIn());
+      auto truncOp = rewriter.replaceOpWithNewOp<LLVM::TruncOp>(
+          op, targetType, adaptor.getIn());
+      if (isExact)
+        truncOp.setOverflowFlags(truncOverflow);
     } else {
       auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
                                                           adaptor.getIn());
@@ -335,15 +354,16 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
       [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
         typename OpTy::Adaptor adaptor(operands);
         if (targetBits < sourceBits) {
-          return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
-                                       adaptor.getIn());
+          auto truncOp = LLVM::TruncOp::create(rewriter, op.getLoc(),
+                                               llvm1DVectorTy, adaptor.getIn());
+          if (isExact)
+            truncOp.setOverflowFlags(truncOverflow);
+          return truncOp;
         }
         auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                        adaptor.getIn());
-        if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
-          if (isNonNeg)
-            extOp.setNonNeg(true);
-        }
+        if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
+          extOp.setNonNeg(isNonNeg);
         return extOp;
       },
       rewriter);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 47069906fa110..2845df23293d5 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -160,6 +160,58 @@ func.func @index_castui_nneg_not_set(%arg0: i1) {
 
 // -----
 
+// index_cast exact on truncation lowers to trunc nsw (signed semantics).
+// CHECK-LABEL: @index_cast_exact_trunc
+func.func @index_cast_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nsw> : i{{.*}} to i1
+  %0 = arith.index_cast %arg0 exact : index to i1
+  return
+}
+
+// -----
+
+// index_cast exact on widening: exact is vacuously true, sext has no flag.
+// CHECK-LABEL: @index_cast_exact_ext
+func.func @index_cast_exact_ext(%arg0: i1) {
+// CHECK: llvm.sext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nsw
+  %0 = arith.index_cast %arg0 exact : i1 to index
+  return
+}
+
+// -----
+
+// index_castui exact on truncation lowers to trunc nuw (unsigned semantics).
+// CHECK-LABEL: @index_castui_exact_trunc
+func.func @index_castui_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nuw> : i{{.*}} to i1
+  %0 = arith.index_castui %arg0 exact : index to i1
+  return
+}
+
+// -----
+
+// index_castui nneg exact on truncation lowers to trunc nuw nsw.
+// CHECK-LABEL: @index_castui_nneg_exact_trunc
+func.func @index_castui_nneg_exact_trunc(%arg0: index) {
+// CHECK: llvm.trunc %{{.*}} overflow<nsw, nuw> : i{{.*}} to i1
+  %0 = arith.index_castui %arg0 nneg exact : index to i1
+  return
+}
+
+// -----
+
+// index_castui exact on widening: exact is vacuously true, zext has no flag.
+// CHECK-LABEL: @index_castui_exact_ext
+func.func @index_castui_exact_ext(%arg0: i1) {
+// CHECK: llvm.zext %{{.*}} : i1 to i{{.*}}
+// CHECK-NOT: nuw
+  %0 = arith.index_castui %arg0 exact : i1 to index
+  return
+}
+
+// -----
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {

>From 74797e1da5f72c87e354096f61ae83bc566faf90 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 26 Feb 2026 08:59:53 -0500
Subject: [PATCH 7/7] Address comments

- Parse the `exact` and `nneg` keywords of `arith.index_castui` with an
  `oilist`, so they may be written in either order; the printer emits
  `exact` before `nneg`.
- Only fold index_castui(index_castui(x)) -> x when the inner cast is
  `exact`. The inner cast is the one that can drop bits of x; `exact` on
  the outer cast alone is not enough (e.g. if `index` is narrower than x's
  type, the inner cast truncates and the widening outer cast is trivially
  exact), so such chains are left untouched.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  2 +-
 .../Dialect/Arith/IR/ArithCanonicalization.td |  6 ++--
 mlir/test/Dialect/Arith/canonicalize.mlir     | 36 ++++++++++++++++++-
 mlir/test/Dialect/Arith/ops.mlir              |  2 +-
 4 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c0987fba98562..d081204154b1f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1648,7 +1648,7 @@ def Arith_IndexCastUIOp
                        UnitAttr:$isExact);
   let results = (outs IndexCastTypeConstraint:$out);
   let assemblyFormat = [{
-    $in (`nneg` $nonNeg^)? (`exact` $isExact^)? attr-dict
+    $in oilist(`exact` $isExact | `nneg` $nonNeg) attr-dict
         `:` type($in) `to` type($out)
   }];
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 5eb973174a53c..f16cd91a37f7c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -303,12 +303,14 @@ def IndexCastOfExtSI :
 // IndexCastUIOp
 //===----------------------------------------------------------------------===//
 
-// index_castui(index_castui(x)) -> x, if dstType == srcType.
+// index_castui(index_castui(x)) -> x, if dstType == srcType and the inner
+// cast is exact (an exact outer cast alone does not preserve x).
 def IndexCastUIOfIndexCastUI :
     Pat<(Arith_IndexCastUIOp:$res
           (Arith_IndexCastUIOp $x, $nneg1, $exact1), $nneg2, $exact2),
         (replaceWithValue $x),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
+         (Constraint<CPred<"static_cast<bool>($0)">> $exact1)]>;
 
 // index_castui(extui(x)) -> index_castui(x)
 def IndexCastUIOfExtUI :
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cfbf93b7f2761..d43af670d2793 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -635,7 +635,7 @@ func.func @indexCastUIOfUnsignedExtend_exact(%arg0: i8) -> index {
 }
 
 // CHECK-LABEL: @indexCastUIOfUnsignedExtend_nneg_exact
-//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 nneg exact : i8 to index
+//       CHECK:   %[[res:.+]] = arith.index_castui %arg0 exact nneg : i8 to index
 //       CHECK:   return %[[res]]
 func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
   %ext = arith.extui %arg0 nneg : i8 to i16
@@ -643,6 +643,40 @@ func.func @indexCastUIOfUnsignedExtend_nneg_exact(%arg0: i8) -> index {
   return %idx : index
 }
 
+// index_castui(index_castui(x)) -> x only when the inner cast is exact.
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_no_exact
+//       CHECK:   arith.index_castui
+//       CHECK:   arith.index_castui
+func.func @indexCastUIOfIndexCastUI_no_exact(%arg0: i32) -> i32 {
+  %idx = arith.index_castui %arg0 : i32 to index
+  %res = arith.index_castui %idx : index to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_inner
+//       CHECK:   return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI_exact_inner(%arg0: i32) -> i32 {
+  %idx = arith.index_castui %arg0 exact : i32 to index
+  %res = arith.index_castui %idx : index to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_outer
+//       CHECK:   arith.index_castui %{{.*}} exact : index to i32
+func.func @indexCastUIOfIndexCastUI_exact_outer(%arg0: i32) -> i32 {
+  %idx = arith.index_castui %arg0 : i32 to index
+  %res = arith.index_castui %idx exact : index to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @indexCastUIOfIndexCastUI_exact_both
+//       CHECK:   return %arg0 : i32
+func.func @indexCastUIOfIndexCastUI_exact_both(%arg0: i32) -> i32 {
+  %idx = arith.index_castui %arg0 exact : i32 to index
+  %res = arith.index_castui %idx exact : index to i32
+  return %res : i32
+}
+
 // CHECK-LABEL: @indexCastFold
 //       CHECK:   %[[res:.*]] = arith.constant -2 : index
 //       CHECK:   return %[[res]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index b7ad2ff8a8694..a9eabe97ebfcd 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -993,7 +993,7 @@ func.func @test_index_castui_exact(%arg0 : i32) -> index {
 }
 
 // CHECK-LABEL: test_index_castui_nneg_exact
-// CHECK: arith.index_castui %{{.*}} nneg exact : i32 to index
+// CHECK: arith.index_castui %{{.*}} exact nneg : i32 to index
 func.func @test_index_castui_nneg_exact(%arg0 : i32) -> index {
   %0 = arith.index_castui %arg0 nneg exact : i32 to index
   return %0 : index



More information about the Mlir-commits mailing list