[Mlir-commits] [mlir] 377efc7 - [mlir][arith] Fix SelectOp unsafe int range inference with uninitialized range case (#173716)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 09:17:34 PST 2026
Author: Longsheng Du
Date: 2026-01-13T09:17:22-08:00
New Revision: 377efc730beb292d1f51025d51ea27d86196ff06
URL: https://github.com/llvm/llvm-project/commit/377efc730beb292d1f51025d51ea27d86196ff06
DIFF: https://github.com/llvm/llvm-project/commit/377efc730beb292d1f51025d51ea27d86196ff06.diff
LOG: [mlir][arith] Fix SelectOp unsafe int range inference with uninitialized range case (#173716)
This PR fixes a bug in `arith::SelectOp::inferResultRangesFromOptional`
where uninitialized SelectOp branch int ranges were incorrectly joined
with initialized int ranges during dataflow analysis, leading to
incorrect folding in `-int-range-optimizations`.
**The Issue:**
When an `arith.select` branch has an uninitialized range (e.g., from an
op like `nvvm.read.ptx.sreg.cluster.ctaid.x`, `scf.switch`, `llvm.call`,
... that lacks range inference), the analysis computed
`IntegerValueRange::join(Uninitialized, Constant) = Constant`. This
caused the `arith.select` to be replaced with the constant, ignoring the
dynamic branch.
**Example:**
```mlir
// The bug before the fix: -int-range-optimizations replaced %1 with %c32,
// leading to incorrect results and unsafe behaviour.
%0 = nvvm.read.ptx.sreg.cluster.ctaid.x : i32 // Uninitialized int range
%c32 = arith.constant 32 : i32
%1 = arith.select %cond, %0, %c32 : i32
```
**The Fix:**
`inferResultRangesFromOptional` now explicitly checks that both select
cases have initialized ranges before joining them. If either case is
uninitialized, the result is set to the max range. Also added a default
max range in `nvvmInferResultRanges` for ops without a `range` attribute,
and a `test.without_bounds` op to simulate and test uninitialized ranges.
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/Arith/int-range-interface.mlir
mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 7673185487eef..49f89e1bd17f3 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -329,7 +329,13 @@ void arith::SelectOp::inferResultRangesFromOptional(
setResultRange(getResult(), trueCase);
return;
}
- setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
+
+ // When one of the ranges is uninitialized, set the whole range to max
+ // otherwise the result will ignore the uninitialized range.
+ if (trueCase.isUninitialized() || falseCase.isUninitialized())
+ setResultRange(getResult(), IntegerValueRange::getMaxRange(getResult()));
+ else
+ setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 331d7a244310f..59f9acf140074 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -4544,6 +4544,8 @@ static void nvvmInferResultRanges(Operation *op, Value result,
if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
rangeAttr.getLower(), rangeAttr.getUpper()});
+ } else {
+ setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
}
}
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 130782ba9f525..30b7128dab42c 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -663,6 +663,24 @@ func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 {
func.return %5 : i1
}
+// CHECK-LABEL: func @select_undefined_union
+// CHECK-COUNT-2: arith.select
+// CHECK: %[[ret:.*]] = arith.cmpi eq
+// CHECK: return %[[ret]]
+
+func.func @select_undefined_union(%arg0: i1) -> i1 {
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %0 = test.without_bounds : index
+ %1 = arith.select %arg0, %0, %c64 : index
+ %2 = arith.cmpi eq, %1, %c64 : index
+ %3 = test.without_bounds : index
+ %4 = arith.select %2, %c32, %3 : index
+ %5 = arith.cmpi eq, %4, %c32 : index
+
+ return %5 : i1
+}
+
// CHECK-LABEL: func @if_union
// CHECK: %[[true:.*]] = arith.constant true
// CHECK: return %[[true]]
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
index fae40dc7806ba..68a3b72653b05 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
@@ -4,6 +4,8 @@ gpu.module @module{
%tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
%tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
%tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
+ %cidx = nvvm.read.ptx.sreg.cluster.ctaid.x : i32 // unspecified range
+ %cond = ub.poison : i1
%c64 = arith.constant 64 : i32
%1 = arith.cmpi sgt, %tidx, %c64 : i32
@@ -18,6 +20,9 @@ gpu.module @module{
scf.if %3 {
gpu.printf "threadidz"
}
+ %4 = arith.select %cond, %cidx, %c64 : i32
+ gpu.printf "ctaidx", %4 : i32
+
gpu.return
}
}
@@ -33,3 +38,4 @@ gpu.module @module{
// CHECK: gpu.printf "threadidy"
// CHECK: scf.if %[[false]] {
// CHECK: gpu.printf "threadidz"
+// CHECK: arith.select
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 868926520af05..765e44f3d16f3 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -863,6 +863,16 @@ void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
}
+//===----------------------------------------------------------------------===//
+// TestWithoutBoundsOp
+//===----------------------------------------------------------------------===//
+
+void TestWithoutBoundsOp::inferResultRangesFromOptional(
+ ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
+ // Mimic ops with uninitialized range.
+ setResultRanges(getResult(), IntegerValueRange{});
+}
+
//===----------------------------------------------------------------------===//
// TestWithBoundsRegionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5417ae94f00d7..6b18d72477a56 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3182,6 +3182,23 @@ def TestWithBoundsOp : TEST_Op<"with_bounds",
let assemblyFormat = "attr-dict `:` type($fakeVal)";
}
+def TestWithoutBoundsOp : TEST_Op<"without_bounds",
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+ NoMemoryEffect]> {
+ let description = [{
+ Creates a value with uninitialized range for integer range analysis tests.
+
+ Example:
+
+ ```mlir
+ %0 = test.without_bounds : index
+ ```
+ }];
+ let results = (outs InferIntRangeType:$result);
+
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
SingleBlock, NoTerminator]> {
More information about the Mlir-commits
mailing list