[Mlir-commits] [mlir] [mlir][arith] Fix SelectOp unsafe int range inference with uninitialized range case (PR #173716)

Longsheng Du llvmlistbot at llvm.org
Mon Jan 12 22:52:14 PST 2026


https://github.com/LongshengDu updated https://github.com/llvm/llvm-project/pull/173716

>From 65f5b35ed24b14bb60b2ed67de731bbe03c39ea5 Mon Sep 17 00:00:00 2001
From: Longsheng Du <longshengd at nvidia.com>
Date: Sat, 27 Dec 2025 03:43:39 -0800
Subject: [PATCH 1/4] Fix arith.select range inference with uninitialized operand ranges

---
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp   |  6 +++++-
 .../Dialect/Arith/int-range-interface.mlir     | 18 ++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp      | 10 ++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td          | 17 +++++++++++++++++
 4 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 7673185487eef..4d20430b39f91 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -329,7 +329,11 @@ void arith::SelectOp::inferResultRangesFromOptional(
       setResultRange(getResult(), trueCase);
     return;
   }
-  setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
+
+  if (trueCase.isUninitialized() || falseCase.isUninitialized())
+    setResultRange(getResult(), IntegerValueRange{});
+  else
+    setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 130782ba9f525..30b7128dab42c 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -663,6 +663,24 @@ func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 {
     func.return %5 : i1
 }
 
+// CHECK-LABEL: func @select_undefined_union
+// CHECK-COUNT-2: arith.select
+// CHECK: %[[ret:.*]] = arith.cmpi eq
+// CHECK: return %[[ret]]
+
+func.func @select_undefined_union(%arg0: i1) -> i1 {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %0 = test.without_bounds : index
+  %1 = arith.select %arg0, %0, %c64 : index
+  %2 = arith.cmpi eq, %1, %c64 : index
+  %3 = test.without_bounds : index
+  %4 = arith.select %2, %c32, %3 : index
+  %5 = arith.cmpi eq, %4, %c32 : index
+
+  return %5 : i1
+}
+
 // CHECK-LABEL: func @if_union
 // CHECK: %[[true:.*]] = arith.constant true
 // CHECK: return %[[true]]
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 868926520af05..110f83c75ef00 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -863,6 +863,16 @@ void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
 }
 
+//===----------------------------------------------------------------------===//
+// TestWithoutBoundsOp
+//===----------------------------------------------------------------------===//
+
+void TestWithoutBoundsOp::inferResultRangesFromOptional(
+    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
+  // mimic ops with uninitialized range
+  setResultRanges(getResult(), IntegerValueRange{});
+}
+
 //===----------------------------------------------------------------------===//
 // TestWithBoundsRegionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5417ae94f00d7..bceb49ebe17f6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3182,6 +3182,23 @@ def TestWithBoundsOp : TEST_Op<"with_bounds",
   let assemblyFormat = "attr-dict `:` type($fakeVal)";
 }
 
+def TestWithoutBoundsOp : TEST_Op<"without_bounds",
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+                         NoMemoryEffect]> {
+  let description = [{
+    Creates a value with uninitialized range for integer range analysis tests.
+
+    Example:
+
+    ```mlir
+    %0 = test.without_bounds : index
+    ```
+  }];
+  let results = (outs InferIntRangeType:$result);
+
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
 def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
                            SingleBlock, NoTerminator]> {

>From 6f06fe99d615e96cd8d937e2cb388370209e7d40 Mon Sep 17 00:00:00 2001
From: Longsheng Du <longshengd at nvidia.com>
Date: Mon, 29 Dec 2025 01:35:21 -0800
Subject: [PATCH 2/4] Use max range for select when an operand range is uninitialized

---
 mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp | 4 +++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                | 2 ++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 4d20430b39f91..1d12eff57aaa4 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -330,8 +330,10 @@ void arith::SelectOp::inferResultRangesFromOptional(
     return;
   }
 
+  // When one of the ranges is uninitialized, set the whole range to max
+  // otherwise the result will ignore the uninitialized range
   if (trueCase.isUninitialized() || falseCase.isUninitialized())
-    setResultRange(getResult(), IntegerValueRange{});
+    setResultRange(getResult(), IntegerValueRange::getMaxRange(getResult()));
   else
     setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 331d7a244310f..59f9acf140074 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -4544,6 +4544,8 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
     setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                              rangeAttr.getLower(), rangeAttr.getUpper()});
+  } else {
+    setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
   }
 }
 

>From eda22713e886d31c197954388f2a66491ec487f5 Mon Sep 17 00:00:00 2001
From: Longsheng Du <longshengd at nvidia.com>
Date: Tue, 30 Dec 2025 17:54:19 +0800
Subject: [PATCH 3/4] Apply suggestions from code review

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp | 2 +-
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp                 | 2 +-
 mlir/test/lib/Dialect/Test/TestOps.td                     | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 1d12eff57aaa4..49f89e1bd17f3 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -331,7 +331,7 @@ void arith::SelectOp::inferResultRangesFromOptional(
   }
 
   // When one of the ranges is uninitialized, set the whole range to max
-  // otherwise the result will ignore the uninitialized range
+  // otherwise the result will ignore the uninitialized range.
   if (trueCase.isUninitialized() || falseCase.isUninitialized())
     setResultRange(getResult(), IntegerValueRange::getMaxRange(getResult()));
   else
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 110f83c75ef00..765e44f3d16f3 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -869,7 +869,7 @@ void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void TestWithoutBoundsOp::inferResultRangesFromOptional(
     ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
-  // mimic ops with uninitialized range
+  // Mimic ops with uninitialized range.
   setResultRanges(getResult(), IntegerValueRange{});
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bceb49ebe17f6..6b18d72477a56 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3184,7 +3184,7 @@ def TestWithBoundsOp : TEST_Op<"with_bounds",
 
 def TestWithoutBoundsOp : TEST_Op<"without_bounds",
                          [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
-                         NoMemoryEffect]> {
+                          NoMemoryEffect]> {
   let description = [{
     Creates a value with uninitialized range for integer range analysis tests.
 

>From c6ddf14c5b59b8e777c9a22fd3545ca360a29736 Mon Sep 17 00:00:00 2001
From: Longsheng Du <longshengd at nvidia.com>
Date: Mon, 12 Jan 2026 22:45:31 -0800
Subject: [PATCH 4/4] Add NVVM range test for an op without a range attribute

---
 mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
index fae40dc7806ba..68a3b72653b05 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
@@ -4,6 +4,8 @@ gpu.module @module{
         %tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
         %tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
         %tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
+        %cidx = nvvm.read.ptx.sreg.cluster.ctaid.x : i32 // unspecified range
+        %cond = ub.poison : i1
         %c64 = arith.constant 64 : i32
         
         %1 = arith.cmpi sgt, %tidx, %c64 : i32
@@ -18,6 +20,9 @@ gpu.module @module{
         scf.if %3 {
             gpu.printf "threadidz"
         }
+        %4 = arith.select %cond, %cidx, %c64 : i32
+        gpu.printf "ctaidx", %4 : i32
+
         gpu.return
     }
 }
@@ -33,3 +38,4 @@ gpu.module @module{
 // CHECK: gpu.printf "threadidy"
 // CHECK: scf.if %[[false]] {
 // CHECK: gpu.printf "threadidz"
+// CHECK: arith.select



More information about the Mlir-commits mailing list