[Mlir-commits] [mlir] [mlir][IntRange] Poison support in int-range analysis (PR #152932)

Ivan Butygin llvmlistbot at llvm.org
Sat Aug 23 02:15:19 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/152932

>From 4d337344bab284ce73663e7bd79182f789013908 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 10 Aug 2025 19:39:54 +0200
Subject: [PATCH 1/6] [mlir][IntRange] Poison support in int-range analysis

---
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  3 +-
 mlir/include/mlir/Dialect/UB/IR/UBOps.h       |  1 +
 mlir/include/mlir/Dialect/UB/IR/UBOps.td      |  6 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  | 15 +++
 .../Transforms/IntRangeOptimizations.cpp      | 25 ++++-
 mlir/lib/Dialect/UB/IR/UBOps.cpp              |  6 ++
 .../lib/Interfaces/InferIntRangeInterface.cpp | 93 +++++++++++++++---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 94 ++++++++++++++-----
 .../Dialect/Arith/int-range-interface.mlir    | 20 ++++
 mlir/test/Dialect/Arith/int-range-opts.mlir   | 12 ++-
 mlir/test/Dialect/UB/int-range-interface.mlir | 24 +++++
 .../Dialect/Vector/int-range-interface.mlir   | 17 ++++
 12 files changed, 277 insertions(+), 39 deletions(-)
 create mode 100644 mlir/test/Dialect/UB/int-range-interface.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   // Explicitly depend on "arith" because this pass could create operations in
   // `arith` out of thin air in some cases.
   let dependentDialects = [
-    "::mlir::arith::ArithDialect"
+    "::mlir::arith::ArithDialect",
+    "::mlir::ub::UBDialect"
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
 #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
 #define MLIR_DIALECT_UB_IR_UBOPS_TD
 
-include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 include "UBOpsInterfaces.td"
 
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
 // PoisonOp
 //===----------------------------------------------------------------------===//
 
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+  DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "Poisoned constant operation.";
   let description = [{
     The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..4e4f3725a69fd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
   /// The maximum value of an integer when it is interpreted as signed.
   const APInt &smax() const;
 
+  /// Get the bitwidth of the ranges.
+  unsigned getBitWidth() const;
+
   /// Return the bitwidth that should be used for integer ranges describing
   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
   /// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
+  /// Create a poisoned range, i.e. a range that represents no valid integer
+  /// values.
+  static ConstantIntRanges poison(unsigned bitwidth);
+
   /// Create a `ConstantIntRanges` with a constant value - that is, with the
   /// bounds [value, value] for both its signed interpretations.
   static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,14 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Returns true if signed range is poisoned, i.e. no valid signed value
+  /// can be represented.
+  bool isSignedPoison() const;
+
+  /// Returns true if unsigned range is poisoned, i.e. no valid unsigned value
+  /// can be represented.
+  bool isUnsignedPoison() const;
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static bool isPoison(DataFlowSolver &solver, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return false;
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
   assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                        RewriterBase &rewriter, Value value) {
   if (value.use_empty())
     return failure();
+
+  if (isPoison(solver, value)) {
+    Value poison =
+        ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+    if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+      solver.eraseState(poison);
+    copyIntegerRange(solver, value, poison);
+    rewriter.replaceAllUsesWith(value, poison);
+    return success();
+  }
+
   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
   if (!maybeConstValue.has_value())
     return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
       return failure();
 
     auto needsReplacing = [&](Value v) {
-      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+      return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+             !v.use_empty();
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+  setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..4f6ab0306229f 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   type = getElementTypeOrSelf(type);
   if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
 }
 
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+  if (bitwidth == 0) {
+    auto zero = APInt::getZero(0);
+    return {zero, zero, zero, zero};
+  }
+
+  // Poison is represented by an empty range.
+  auto zero = APInt::getZero(bitwidth);
+  auto one = zero + 1;
+  auto onem = zero - 1;
+  // umin (1) u> umax (0) and smin (0) s> smax (-1) for every bitwidth >= 1,
+  // including i1, where the valid ranges are [0, 1] unsigned / [-1, 0] signed.
+  return {one, zero, zero, onem};
+}
+
 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
   return {value, value, value, value};
 }
@@ -85,15 +102,37 @@ ConstantIntRanges
 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
-  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
-  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  APInt uminUnion;
+  APInt umaxUnion;
+  APInt sminUnion;
+  APInt smaxUnion;
+
+  if (isUnsignedPoison()) {
+    uminUnion = other.umin();
+    umaxUnion = other.umax();
+  } else if (other.isUnsignedPoison()) {
+    uminUnion = umin();
+    umaxUnion = umax();
+  } else {
+    uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+    umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminUnion = other.smin();
+    smaxUnion = other.smax();
+  } else if (other.isSignedPoison()) {
+    sminUnion = smin();
+    smaxUnion = smax();
+  } else {
+    sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+    smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
@@ -102,15 +141,37 @@ ConstantIntRanges
 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
-  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
-  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  APInt uminIntersect;
+  APInt umaxIntersect;
+  APInt sminIntersect;
+  APInt smaxIntersect;
+
+  if (isUnsignedPoison()) {
+    uminIntersect = umin();
+    umaxIntersect = umax();
+  } else if (other.isUnsignedPoison()) {
+    uminIntersect = other.umin();
+    umaxIntersect = other.umax();
+  } else {
+    uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+    umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminIntersect = smin();
+    smaxIntersect = smax();
+  } else if (other.isSignedPoison()) {
+    sminIntersect = other.smin();
+    smaxIntersect = other.smax();
+  } else {
+    sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+    smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
 }
@@ -124,6 +185,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
   return std::nullopt;
 }
 
+bool ConstantIntRanges::isSignedPoison() const {
+  return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+  return getBitWidth() > 0 && umin().ugt(umax());
+}
+
 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   os << "unsigned : [";
   range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index af4ea5ac1cec8..eeb4c4ac0446f 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// For each signedness, poison the result if any argument range is poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+                ArrayRef<ConstantIntRanges> argRanges) {
+  APInt umin = newRange.umin();
+  APInt umax = newRange.umax();
+  APInt smin = newRange.smin();
+  APInt smax = newRange.smax();
+
+  unsigned width = umin.getBitWidth();
+  for (const ConstantIntRanges &argRange : argRanges) {
+    if (argRange.isSignedPoison()) {
+      smin = APInt::getZero(width);
+      smax = smin - 1;
+    }
+    if (argRange.isUnsignedPoison()) {
+      umax = APInt::getZero(width);
+      umin = umax + 1;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -114,7 +137,7 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
     // Returing the 64-bit result preserves more information.
     return sixtyFour;
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
+  return propagatePoison(merged, argRanges);
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +146,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
+  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +196,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 //===----------------------------------------------------------------------===//
@@ -206,7 +229,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,7 +261,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +296,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -305,7 +328,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -350,10 +374,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
+    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                                    /*isSigned=*/true),
+                           {lhs, rhs});
   }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -394,7 +420,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return result;
+  return propagatePoison(result, argRanges);
 }
 
 ConstantIntRanges
@@ -424,6 +450,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -462,6 +491,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -491,6 +523,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -500,6 +534,8 @@ mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -509,6 +545,8 @@ mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -518,6 +556,8 @@ mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -549,8 +589,10 @@ mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a & b;
   };
-  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(andi, {lhsZeros, lhsOnes},
+                                  {rhsZeros, rhsOnes},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 ConstantIntRanges
@@ -560,8 +602,9 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a | b;
   };
-  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 /// Get bitmask of all bits which can change while iterating in
@@ -578,6 +621,9 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
   // Construct mask of varying bits for both ranges, xor values and then replace
   // masked bits with 0s and 1s to get min and max values respectively.
   ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
+
   APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
   APInt res = lhs.umin() ^ rhs.umin();
   APInt min = res & ~mask;
@@ -620,7 +666,7 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 ConstantIntRanges
@@ -631,8 +677,10 @@ mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
   };
 
-  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/true);
+  return propagatePoison(minMaxBy(ashr, {lhs.smin(), lhs.smax()},
+                                  {rhs.umin(), rhs.umax()},
+                                  /*isSigned=*/true),
+                         argRanges);
 }
 
 ConstantIntRanges
@@ -642,8 +690,10 @@ mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
   auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
   };
-  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(lshr, {lhs.umin(), lhs.umax()},
+                                  {rhs.umin(), rhs.umax()},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 130782ba9f525..7d43336cecd71 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -663,6 +663,26 @@ func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 {
     func.return %5 : i1
 }
 
+// CHECK-LABEL: func @select_poison
+// CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
+func.func @select_poison(%arg0: i1) -> index {
+    %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+    %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index
+    %2 = arith.select %arg0, %0, %1 : index
+    %3 = test.reflect_bounds %2 : index
+    func.return %3 : index
+}
+
+// CHECK-LABEL: func @add_poison
+// CHECK: test.reflect_bounds {smax = -1 : index, smin = 0 : index, umax = 0 : index, umin = 1 : index}
+func.func @add_poison() -> index {
+    %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+    %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index
+    %2 = arith.addi %0, %1 : index
+    %3 = test.reflect_bounds %2 : index
+    func.return %3 : index
+}
+
 // CHECK-LABEL: func @if_union
 // CHECK: %[[true:.*]] = arith.constant true
 // CHECK: return %[[true]]
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..d613f38a55f01 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --int-range-optimizations --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @test
 //       CHECK:   %[[C:.*]] = arith.constant false
@@ -132,3 +132,13 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// CHECK-LABEL: func @create_poison_op
+// CHECK: %[[RES:.*]] = ub.poison : i32
+// CHECK: return %[[RES]]
+func.func @create_poison_op() -> i32 {
+  %val = test.with_bounds { umin = 1 : i32, umax = 0 : i32, smin = 1 : i32, smax = 0 : i32 } : i32
+  return %val : i32
+}
diff --git a/mlir/test/Dialect/UB/int-range-interface.mlir b/mlir/test/Dialect/UB/int-range-interface.mlir
new file mode 100644
index 0000000000000..69f4923ffe6c7
--- /dev/null
+++ b/mlir/test/Dialect/UB/int-range-interface.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// CHECK-LABEL: func @poison
+// CHECK: test.reflect_bounds {smax = -1 : si32, smin = 0 : si32, umax = 0 : ui32, umin = 1 : ui32}
+func.func @poison() -> i32 {
+    %0 = ub.poison : i32
+    %1 = test.reflect_bounds %0 : i32
+    func.return %1 : i32
+}
+
+// CHECK-LABEL: func @poison_i1
+// CHECK: test.reflect_bounds {smax = -1 : si1, smin = 0 : si1, umax = 0 : ui1, umin = 1 : ui1}
+func.func @poison_i1() -> i1 {
+    %0 = ub.poison : i1
+    %1 = test.reflect_bounds %0 : i1
+    func.return %1 : i1
+}
+
+// CHECK-LABEL: func @poison_non_int
+// Check it doesn't crash.
+func.func @poison_non_int() -> f32 {
+    %0 = ub.poison : f32
+    func.return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index b2f16bb3dac9c..3182ac6bf8b4b 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -116,3 +116,20 @@ func.func @vector_step() -> vector<8xindex> {
   %1 = test.reflect_bounds %0 : vector<8xindex>
   func.return %1 : vector<8xindex>
 }
+
+// CHECK-LABEL: func @poison_vector_insert
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 1 : index, umax = 4 : index, umin = 1 : index}
+func.func @poison_vector_insert() -> vector<4xindex> {
+  %0 = ub.poison : vector<4xindex>
+  %1 = test.with_bounds { umin = 1 : index, umax = 1 : index, smin = 1 : index, smax = 1 : index } : index
+  %2 = test.with_bounds { umin = 2 : index, umax = 2 : index, smin = 2 : index, smax = 2 : index } : index
+  %3 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 3 : index, smax = 3 : index } : index
+  %4 = test.with_bounds { umin = 4 : index, umax = 4 : index, smin = 4 : index, smax = 4 : index } : index
+  %5 = vector.insert %1, %0[0] : index into vector<4xindex>
+  %6 = vector.insert %2, %5[1] : index into vector<4xindex>
+  %7 = vector.insert %3, %6[2] : index into vector<4xindex>
+  %8 = vector.insert %4, %7[3] : index into vector<4xindex>
+
+  %9 = test.reflect_bounds %8 : vector<4xindex>
+  func.return %9 : vector<4xindex>
+}

>From 7a68b15253c79754e7e0a4f0371b6e6efd5faccd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 16 Aug 2025 12:01:28 +0200
Subject: [PATCH 2/6] missing poison check

Signed-off-by: Ivan Butygin <ivan.butygin at gmail.com>
---
 mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index eeb4c4ac0446f..77d1bb00edd6c 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -135,7 +135,8 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
   }
   if (truncEqual)
     // Returing the 64-bit result preserves more information.
-    return sixtyFour;
+    return propagatePoison(sixtyFour, argRanges);
+
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
   return propagatePoison(merged, argRanges);
 }

>From 82548361cc8a2510333c32ad6ced87d7ee493211 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 16 Aug 2025 12:42:59 +0200
Subject: [PATCH 3/6] more comments

---
 mlir/include/mlir/Interfaces/InferIntRangeInterface.h | 10 ++++++----
 mlir/lib/Interfaces/InferIntRangeInterface.cpp        |  8 ++++++++
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 4e4f3725a69fd..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -103,12 +103,14 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
-  /// Returns true if signed range is poisoned, i.e. no valid signed value
-  /// can be represented.
+  /// Returns true if the signed range is poisoned. Poisoned ranges propagate
+  /// through the DAG and cause immediate UB if they reach a side-effecting
+  /// operation.
   bool isSignedPoison() const;
 
-  /// Returns true if unsigned range is poisoned, i.e. no valid unsigned value
-  /// can be represented.
+  /// Returns true if the unsigned range is poisoned. Poisoned ranges propagate
+  /// through the DAG and cause immediate UB if they reach a side-effecting
+  /// operation.
   bool isUnsignedPoison() const;
 
   friend raw_ostream &operator<<(raw_ostream &os,
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 4f6ab0306229f..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -112,6 +112,13 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   APInt sminUnion;
   APInt smaxUnion;
 
+  // Union of a poisoned range with any other range is the other range.
+  // Union is used when we need to merge ranges from multiple independent
+  // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+  // value (using it in a side-effecting operation) causes immediate UB.
+  // Well-formed programs never trigger immediate UB, so we assume the
+  // result is either unused or only used in cases where it received the
+  // non-poisoned argument.
   if (isUnsignedPoison()) {
     uminUnion = other.umin();
     umaxUnion = other.umax();
@@ -151,6 +158,7 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   APInt sminIntersect;
   APInt smaxIntersect;
 
+  // Intersection of poisoned range with any other range is poisoned.
   if (isUnsignedPoison()) {
     uminIntersect = umin();
     umaxIntersect = umax();

>From ca8b8a4ee7162507601a986f4d1618dc9d957aa5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 16 Aug 2025 13:29:21 +0200
Subject: [PATCH 4/6] comment

---
 mlir/include/mlir/Interfaces/InferIntRangeInterface.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index cfa6cbdade21c..e3fac84c336eb 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -65,8 +65,8 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
-  /// Create a poisoned range, i.e. a range that represents no valid integer
-  /// values.
+  /// Create a poisoned range. Poisoned ranges propagate through the DAG and
+  /// cause immediate UB if they reach a side-effecting operation.
   static ConstantIntRanges poison(unsigned bitwidth);
 
   /// Create a `ConstantIntRanges` with a constant value - that is, with the

>From 039efae94608733ae8255c5a3cc2322ff052f03f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 23 Aug 2025 11:07:08 +0200
Subject: [PATCH 5/6] propagate poison as part of the interface

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 142 +++++++++---------
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   2 +-
 mlir/include/mlir/IR/Matchers.h               |   2 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  |   7 +
 .../mlir/Interfaces/InferIntRangeInterface.td |  30 ++++
 .../DataFlow/IntegerRangeAnalysis.cpp         |   4 +-
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  |   2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  11 +-
 .../lib/Interfaces/InferIntRangeInterface.cpp |  35 +++++
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  97 +++---------
 10 files changed, 179 insertions(+), 153 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ef9ccb7e87946..c13e29c42e811 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -216,14 +216,14 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
 def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
-    Performs N-bit addition on the operands. The operands are interpreted as 
-    unsigned bitvectors. The result is represented by a bitvector containing the 
-    mathematical value of the addition modulo 2^n, where `n` is the bitwidth. 
-    Because `arith` integers use a two's complement representation, this operation 
+    Performs N-bit addition on the operands. The operands are interpreted as
+    unsigned bitvectors. The result is represented by a bitvector containing the
+    mathematical value of the addition modulo 2^n, where `n` is the bitwidth.
+    Because `arith` integers use a two's complement representation, this operation
     is applicable on both signed and unsigned integer operands.
 
     The `addi` operation takes two operands and returns one result, each of
-    these is required to be the same type. This type may be an integer scalar type, 
+    these is required to be the same type. This type may be an integer scalar type,
     a vector whose element type is integer, or a tensor of integers.
 
     This op supports `nuw`/`nsw` overflow flags which stands for
@@ -489,8 +489,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
     the most significant, i.e. for `i16` given two's complement representation,
     `6 / -2 = 6 / (2^16 - 2) = 0`.
 
-    Division by zero is undefined behavior. When applied to `vector` and 
-    `tensor` values, the behavior is undefined if _any_ elements are divided by 
+    Division by zero is undefined behavior. When applied to `vector` and
+    `tensor` values, the behavior is undefined if _any_ elements are divided by
     zero.
 
     Example:
@@ -525,9 +525,9 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
     Signed integer division. Rounds towards zero. Treats the leading bit as
     sign, i.e. `6 / -2 = -3`.
 
-    Divison by zero, or signed division overflow (minimum value divided by -1) 
-    is undefined behavior. When applied to `vector` and `tensor` values, the 
-    behavior is undefined if _any_ of its elements are divided by zero or has a 
+    Division by zero, or signed division overflow (minimum value divided by -1)
+    is undefined behavior. When applied to `vector` and `tensor` values, the
+    behavior is undefined if _any_ of its elements are divided by zero or has a
     signed division overflow.
 
     Example:
@@ -562,10 +562,10 @@ def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui",
   let description = [{
     Unsigned integer division. Rounds towards positive infinity. Treats the
     leading bit as the most significant, i.e. for `i16` given two's complement
-    representation, `6 / -2 = 6 / (2^16 - 2) = 1`. 
+    representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
 
-    Division by zero is undefined behavior. When applied to `vector` and 
-    `tensor` values, the behavior is undefined if _any_ elements are divided by 
+    Division by zero is undefined behavior. When applied to `vector` and
+    `tensor` values, the behavior is undefined if _any_ elements are divided by
     zero.
 
     Example:
@@ -594,9 +594,9 @@ def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi",
   let description = [{
     Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
 
-    Divison by zero, or signed division overflow (minimum value divided by -1) 
-    is undefined behavior. When applied to `vector` and `tensor` values, the 
-    behavior is undefined if _any_ of its elements are divided by zero or has a 
+    Division by zero, or signed division overflow (minimum value divided by -1)
+    is undefined behavior. When applied to `vector` and `tensor` values, the
+    behavior is undefined if _any_ of its elements are divided by zero or has a
     signed division overflow.
 
     Example:
@@ -624,9 +624,9 @@ def Arith_FloorDivSIOp : Arith_TotalIntBinaryOp<"floordivsi"> {
   let description = [{
     Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`.
 
-    Divison by zero, or signed division overflow (minimum value divided by -1) 
-    is undefined behavior. When applied to `vector` and `tensor` values, the 
-    behavior is undefined if _any_ of its elements are divided by zero or has a 
+    Division by zero, or signed division overflow (minimum value divided by -1)
+    is undefined behavior. When applied to `vector` and `tensor` values, the
+    behavior is undefined if _any_ of its elements are divided by zero or has a
     signed division overflow.
 
     Example:
@@ -650,8 +650,8 @@ def Arith_RemUIOp : Arith_TotalIntBinaryOp<"remui"> {
     Unsigned integer division remainder. Treats the leading bit as the most
     significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`.
 
-    Division by zero is undefined behavior. When applied to `vector` and 
-    `tensor` values, the behavior is undefined if _any_ elements are divided by 
+    Division by zero is undefined behavior. When applied to `vector` and
+    `tensor` values, the behavior is undefined if _any_ elements are divided by
     zero.
 
     Example:
@@ -680,8 +680,8 @@ def Arith_RemSIOp : Arith_TotalIntBinaryOp<"remsi"> {
     Signed integer division remainder. Treats the leading bit as sign, i.e. `6 %
     -2 = 0`.
 
-    Division by zero is undefined behavior. When applied to `vector` and 
-    `tensor` values, the behavior is undefined if _any_ elements are divided by 
+    Division by zero is undefined behavior. When applied to `vector` and
+    `tensor` values, the behavior is undefined if _any_ elements are divided by
     zero.
 
     Example:
@@ -794,9 +794,9 @@ def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> {
 def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
   let summary = "integer left-shift";
   let description = [{
-    The `shli` operation shifts the integer value of the first operand to the left 
-    by the integer value of the second operand. The second operand is interpreted as 
-    unsigned. The low order bits are filled with zeros. If the value of the second 
+    The `shli` operation shifts the integer value of the first operand to the left
+    by the integer value of the second operand. The second operand is interpreted as
+    unsigned. The low order bits are filled with zeros. If the value of the second
     operand is greater or equal than the bitwidth of the first operand, then the
     operation returns poison.
 
@@ -811,7 +811,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
     %1 = arith.constant 5 : i8  // %1 is 0b00000101
     %2 = arith.constant 3 : i8
     %3 = arith.shli %1, %2 : i8 // %3 is 0b00101000
-    %4 = arith.shli %1, %2 overflow<nsw, nuw> : i8  
+    %4 = arith.shli %1, %2 overflow<nsw, nuw> : i8
     ```
   }];
   let hasFolder = 1;
@@ -824,9 +824,9 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
 def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
   let summary = "unsigned integer right-shift";
   let description = [{
-    The `shrui` operation shifts an integer value of the first operand to the right 
+    The `shrui` operation shifts an integer value of the first operand to the right
     by the value of the second operand. The first operand is interpreted as unsigned,
-    and the second operand is interpreted as unsigned. The high order bits are always 
+    and the second operand is interpreted as unsigned. The high order bits are always
     filled with zeros. If the value of the second operand is greater or equal than the
     bitwidth of the first operand, then the operation returns poison.
 
@@ -848,11 +848,11 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
 def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
   let summary = "signed integer right-shift";
   let description = [{
-    The `shrsi` operation shifts an integer value of the first operand to the right 
-    by the value of the second operand. The first operand is interpreted as signed, 
-    and the second operand is interpreter as unsigned. The high order bits in the 
-    output are filled with copies of the most-significant bit of the shifted value 
-    (which means that the sign of the value is preserved). If the value of the second 
+    The `shrsi` operation shifts an integer value of the first operand to the right
+    by the value of the second operand. The first operand is interpreted as signed,
+    and the second operand is interpreted as unsigned. The high order bits in the
+    output are filled with copies of the most-significant bit of the shifted value
+    (which means that the sign of the value is preserved). If the value of the second
     operand is greater or equal than bitwidth of the first operand, then the operation
     returns poison.
 
@@ -1229,28 +1229,28 @@ def Arith_ScalingExtFOp
   let summary = "Upcasts input floats using provided scales values following "
                 "OCP MXFP Spec";
   let description = [{
-  This operation upcasts input floating-point values using provided scale 
-  values. It expects both scales and the input operand to be of the same shape, 
-  making the operation elementwise. Scales are usually calculated per block 
+  This operation upcasts input floating-point values using provided scale
+  values. It expects both scales and the input operand to be of the same shape,
+  making the operation elementwise. Scales are usually calculated per block
   following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
 
-  If scales are calculated per block where blockSize != 1, then scales may 
-  require broadcasting to make this operation elementwise. For example, let's 
-  say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
-  assuming quantization happens on the last axis, the input can be reshaped to 
-  `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
-  per block on the last axis. Therefore, scales will be of shape 
-  `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
-  shape as long as it is broadcast compatible with the input, e.g., 
+  If scales are calculated per block where blockSize != 1, then scales may
+  require broadcasting to make this operation elementwise. For example, let's
+  say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+  assuming quantization happens on the last axis, the input can be reshaped to
+  `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+  per block on the last axis. Therefore, scales will be of shape
+  `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+  shape as long as it is broadcast compatible with the input, e.g.,
   `<1 x 1 x ... (dimN/blockSize) x 1>`.
 
-  In this example, before calling into `arith.scaling_extf`, scales must be 
-  broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
-  that there could be multiple quantization axes. Internally, 
+  In this example, before calling into `arith.scaling_extf`, scales must be
+  broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+  that there could be multiple quantization axes. Internally,
   `arith.scaling_extf` would perform the following:
- 
+
     ```
-    resultTy = get_type(result) 
+    resultTy = get_type(result)
     scaleTy  = get_type(scale)
     inputTy = get_type(input)
     scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
@@ -1258,12 +1258,12 @@ def Arith_ScalingExtFOp
     input.extf = arith.extf(input) : inputTy to resultTy
     result = arith.mulf(scale.extf, input.extf)
     ```
-    It propagates NaN values. Therefore, if either scale or the input element 
+    It propagates NaN values. Therefore, if either scale or the input element
     contains NaN, then the output element value will also be a NaN.
   }];
   let hasVerifier = 1;
   let assemblyFormat =
-      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` 
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
       type($in) `,` type($scale) `to` type($out)}];
 }
 
@@ -1373,28 +1373,28 @@ def Arith_ScalingTruncFOp
   let summary = "Downcasts input floating point values using provided scales "
                 "values following OCP MXFP Spec";
   let description = [{
-    This operation downcasts input using the provided scale values. It expects 
-    both scales and the input operand to be of the same shape and, therefore, 
-    makes the operation elementwise. Scales are usually calculated per block 
+    This operation downcasts input using the provided scale values. It expects
+    both scales and the input operand to be of the same shape and, therefore,
+    makes the operation elementwise. Scales are usually calculated per block
     following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
     Users are required to normalize and clamp the scales as necessary before calling
     passing them to this operation.  OCP MXFP spec also does the flushing of denorms
-    on the input operand, which should be handled during lowering by passing appropriate 
-    fastMath flag to this operation. 
-
-    If scales are calculated per block where blockSize != 1, scales may require 
-    broadcasting to make this operation elementwise. For example, let's say the 
-    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and 
-    assuming quantization happens on the last axis, the input can be reshaped to 
-    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated 
-    per block on the last axis. Therefore, scales will be of shape 
-    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other 
-    shape as long as it is broadcast compatible with the input, e.g., 
+    on the input operand, which should be handled during lowering by passing appropriate
+    fastMath flag to this operation.
+
+    If scales are calculated per block where blockSize != 1, scales may require
+    broadcasting to make this operation elementwise. For example, let's say the
+    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+    assuming quantization happens on the last axis, the input can be reshaped to
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+    per block on the last axis. Therefore, scales will be of shape
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+    shape as long as it is broadcast compatible with the input, e.g.,
     `<1 x 1 x ... (dimN/blockSize) x 1>`.
 
-    In this example, before calling into `arith.scaling_truncf`, scales must be 
-    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note 
-    that there could be multiple quantization axes. Internally, 
+    In this example, before calling into `arith.scaling_truncf`, scales must be
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+    that there could be multiple quantization axes. Internally,
     `arith.scaling_truncf` would perform the following:
 
     ```
@@ -1409,7 +1409,7 @@ def Arith_ScalingTruncFOp
   }];
   let hasVerifier = 1;
   let assemblyFormat =
-      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` 
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
       type($in) `,` type($scale) `to` type($out)}];
 }
 
@@ -1687,7 +1687,7 @@ class BooleanConditionOrMatchingShape<string condition, string result> :
 def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
     BooleanConditionOrMatchingShape<"condition", "result">,
-    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesOrPoison"]>,
     DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
   let summary = "select operation";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c7b83674fb009..59d2749799724 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -843,7 +843,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
 
 def Vector_InsertOp :
   Vector_Op<"insert", [Pure,
-     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesOrPoison"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      AllTypesMatch<["dest", "result"]>]> {
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index e577909621cb8..4ec947b59520b 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -129,7 +129,7 @@ struct infer_int_range_op_binder {
       *bind_value = argRanges;
       matched = true;
     };
-    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    inferIntRangeOp.inferResultRangesOrPoison(argRanges, setResultRanges);
     return matched;
   }
 };
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index e3fac84c336eb..0943d9a3e6030 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -198,6 +198,13 @@ void defaultInferResultRanges(InferIntRangeInterface interface,
 void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
                                           ArrayRef<ConstantIntRanges> argRanges,
                                           SetIntRangeFn setResultRanges);
+
+/// Default implementation of `inferResultRangesOrPoison` which dispatches to
+/// `inferResultRangesFromOptional` and propagates operand poison to results.
+void defaultInferResultRangesOrPoison(InferIntRangeInterface interface,
+                                      ArrayRef<IntegerValueRange> argRanges,
+                                      SetIntLatticeFn setResultRanges);
+
 } // end namespace intrange::detail
 } // end namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index 6ee436ce4d6c2..db6d686659ca5 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -33,6 +33,9 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
       When operations take non-integer inputs, the
      `inferResultRangesFromOptional` method should be implemented instead.
 
+      If any operand has a poison range, the default `inferResultRangesOrPoison`
+      implementation automatically propagates the poison to the results.
+
       When called on an op that also implements the RegionBranchOpInterface
       or BranchOpInterface, this method should not attempt to infer the values
       of the branch results, as this will be handled by the analyses that use
@@ -60,6 +63,9 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
       as an argument. When implemented, `setValueRange` should be called on
       all result values for the operation.
 
+      If any operand has a poison range, the default `inferResultRangesOrPoison`
+      implementation automatically propagates the poison to the results.
+
       This method allows for more precise implementations when operations
       want to reason about inputs which may be undefined during the analysis.
     }],
@@ -72,6 +78,30 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
       ::mlir::intrange::detail::defaultInferResultRanges($_op,
                                                          argRanges,
                                                          setResultRanges);
+    }]>,
+
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the lattice representation
+      of the bounds for its arguments. For each result value or block argument
+      (that isn't a branch argument, since the dataflow analysis handles
+      those cases), the method should call `setValueRange` with that `Value`
+      as an argument. When implemented, `setValueRange` should be called on
+      all result values for the operation.
+
+      Unlike `inferResultRanges`/`inferResultRangesFromOptional`, this method
+      does not automatically propagate poison from the inputs, which lets ops
+      implement more precise poison semantics.
+    }],
+    /*retTy=*/"void",
+    /*methodName=*/"inferResultRangesOrPoison",
+    /*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
+                  "::mlir::SetIntLatticeFn":$setResultRanges),
+    /*methodBody=*/"",
+    /*defaultImplementation=*/[{
+      ::mlir::intrange::detail::defaultInferResultRangesOrPoison(
+        $_op,
+        argRanges,
+        setResultRanges);
     }]>
   ];
 }
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index c7a950d9a8871..b913f951b7abf 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -120,7 +120,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
     propagateIfChanged(lattice, changed);
   };
 
-  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
+  inferrable.inferResultRangesOrPoison(argRanges, joinCallback);
   return success();
 }
 
@@ -162,7 +162,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       propagateIfChanged(lattice, changed);
     };
 
-    inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
+    inferrable.inferResultRangesOrPoison(argRanges, joinCallback);
     return;
   }
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 7673185487eef..7dd0d881fe587 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -312,7 +312,7 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void arith::SelectOp::inferResultRangesFromOptional(
+void arith::SelectOp::inferResultRangesOrPoison(
     ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
   std::optional<APInt> mbCondVal =
       argRanges[0].isUninitialized()
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2b2581d353673..4ba6504aa4224 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3207,9 +3207,14 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                         SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
+void vector::InsertOp::inferResultRangesOrPoison(
+    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
+  if (argRanges[0].isUninitialized() || argRanges[1].isUninitialized())
+    return;
+  // Poison is not propagated here: rangeUnion keeps the non-poisoned side.
+  const ConstantIntRanges &range0 = argRanges[0].getValue();
+  const ConstantIntRanges &range1 = argRanges[1].getValue();
+  setResultRanges(getResult(), range0.rangeUnion(range1));
 }
 
 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 46b5604bb5731..24a55a86f36b2 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -253,3 +253,38 @@ void mlir::intrange::detail::defaultInferResultRangesFromOptional(
           setResultRanges(value, argRanges.getValue());
       });
 }
+
+void mlir::intrange::detail::defaultInferResultRangesOrPoison(
+    InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
+    SetIntLatticeFn setResultRanges) {
+  // Record whether any initialized operand is signed- or unsigned-poisoned.
+  bool signedPoison = false;
+  bool unsignedPoison = false;
+  for (const IntegerValueRange &range : argRanges) {
+    if (range.isUninitialized())
+      continue;
+
+    const ConstantIntRanges &value = range.getValue();
+    signedPoison = signedPoison || value.isSignedPoison();
+    unsignedPoison = unsignedPoison || value.isUnsignedPoison();
+  }
+  // Forward each inferred result range, replacing the poisoned halves.
+  auto visitor = [&](Value value, const IntegerValueRange &range) {
+    if (range.isUninitialized())
+      return;
+
+    if (!signedPoison && !unsignedPoison)
+      return setResultRanges(value, range);
+    // Some operand is poisoned: overwrite the matching half with poison bounds.
+    const ConstantIntRanges &origRange = range.getValue();
+    auto poison = ConstantIntRanges::poison(origRange.getBitWidth());
+    APInt umin = unsignedPoison ? poison.umin() : origRange.umin();
+    APInt umax = unsignedPoison ? poison.umax() : origRange.umax();
+    APInt smin = signedPoison ? poison.smin() : origRange.smin();
+    APInt smax = signedPoison ? poison.smax() : origRange.smax();
+
+    setResultRanges(value, ConstantIntRanges(umin, umax, smin, smax));
+  };
+  // Run the op's own inference; `visitor` post-processes every result range.
+  interface.inferResultRangesFromOptional(argRanges, visitor);
+}
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 77d1bb00edd6c..af4ea5ac1cec8 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,29 +32,6 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
-/// If any of the arguments are poison, return poison.
-static ConstantIntRanges
-propagatePoison(const ConstantIntRanges &newRange,
-                ArrayRef<ConstantIntRanges> argRanges) {
-  APInt umin = newRange.umin();
-  APInt umax = newRange.umax();
-  APInt smin = newRange.smin();
-  APInt smax = newRange.smax();
-
-  unsigned width = umin.getBitWidth();
-  for (const ConstantIntRanges &argRange : argRanges) {
-    if (argRange.isSignedPoison()) {
-      smin = APInt::getZero(width);
-      smax = smin - 1;
-    }
-    if (argRange.isUnsignedPoison()) {
-      umax = APInt::getZero(width);
-      umin = umax + 1;
-    }
-  }
-  return {umin, umax, smin, smax};
-}
-
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -135,10 +112,9 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
   }
   if (truncEqual)
     // Returing the 64-bit result preserves more information.
-    return propagatePoison(sixtyFour, argRanges);
-
+    return sixtyFour;
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return propagatePoison(merged, argRanges);
+  return merged;
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -147,21 +123,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return propagatePoison({umin, umax, smin, smax}, range);
+  return {umin, umax, smin, smax};
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
+  return ConstantIntRanges::fromUnsigned(umin, umax);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
+  return ConstantIntRanges::fromSigned(smin, smax);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -197,7 +173,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return propagatePoison({umin, umax, smin, smax}, range);
+  return {umin, umax, smin, smax};
 }
 
 //===----------------------------------------------------------------------===//
@@ -230,7 +206,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return propagatePoison(urange.intersection(srange), argRanges);
+  return urange.intersection(srange);
 }
 
 //===----------------------------------------------------------------------===//
@@ -262,7 +238,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return propagatePoison(urange.intersection(srange), argRanges);
+  return urange.intersection(srange);
 }
 
 //===----------------------------------------------------------------------===//
@@ -297,7 +273,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return propagatePoison(urange.intersection(srange), argRanges);
+  return urange.intersection(srange);
 }
 
 //===----------------------------------------------------------------------===//
@@ -329,8 +305,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
-                         {lhs, rhs});
+  return ConstantIntRanges::fromUnsigned(umin, umax);
 }
 
 ConstantIntRanges
@@ -375,12 +350,10 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                                    /*isSigned=*/true),
-                           {lhs, rhs});
+    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/true);
   }
-  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
-                         {lhs, rhs});
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
 }
 
 ConstantIntRanges
@@ -421,7 +394,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return propagatePoison(result, argRanges);
+  return result;
 }
 
 ConstantIntRanges
@@ -451,9 +424,6 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
-  if (lhs.isSignedPoison() || rhs.isSignedPoison())
-    return ConstantIntRanges::poison(rhsMin.getBitWidth());
-
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -492,9 +462,6 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
-  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
-    return ConstantIntRanges::poison(rhsMin.getBitWidth());
-
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -524,8 +491,6 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  if (lhs.isSignedPoison() || rhs.isSignedPoison())
-    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -535,8 +500,6 @@ mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
-    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -546,8 +509,6 @@ mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  if (lhs.isSignedPoison() || rhs.isSignedPoison())
-    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -557,8 +518,6 @@ mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
-    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -590,10 +549,8 @@ mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a & b;
   };
-  return propagatePoison(minMaxBy(andi, {lhsZeros, lhsOnes},
-                                  {rhsZeros, rhsOnes},
-                                  /*isSigned=*/false),
-                         argRanges);
+  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                  /*isSigned=*/false);
 }
 
 ConstantIntRanges
@@ -603,9 +560,8 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a | b;
   };
-  return propagatePoison(minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                                  /*isSigned=*/false),
-                         argRanges);
+  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                  /*isSigned=*/false);
 }
 
 /// Get bitmask of all bits which can change while iterating in
@@ -622,9 +578,6 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
   // Construct mask of varying bits for both ranges, xor values and then replace
   // masked bits with 0s and 1s to get min and max values respectively.
   ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
-  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
-    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
-
   APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
   APInt res = lhs.umin() ^ rhs.umin();
   APInt min = res & ~mask;
@@ -667,7 +620,7 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
-  return propagatePoison(urange.intersection(srange), argRanges);
+  return urange.intersection(srange);
 }
 
 ConstantIntRanges
@@ -678,10 +631,8 @@ mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
   };
 
-  return propagatePoison(minMaxBy(ashr, {lhs.smin(), lhs.smax()},
-                                  {rhs.umin(), rhs.umax()},
-                                  /*isSigned=*/true),
-                         argRanges);
+  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+                  /*isSigned=*/true);
 }
 
 ConstantIntRanges
@@ -691,10 +642,8 @@ mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
   auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
   };
-  return propagatePoison(minMaxBy(lshr, {lhs.umin(), lhs.umax()},
-                                  {rhs.umin(), rhs.umax()},
-                                  /*isSigned=*/false),
-                         argRanges);
+  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+                  /*isSigned=*/false);
 }
 
 //===----------------------------------------------------------------------===//

>From b3f85159c17fd3bf17b824436dcf9bd72557df89 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 23 Aug 2025 11:15:01 +0200
Subject: [PATCH 6/6] comment

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 03da1e5327e39..192258b003aef 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -54,6 +54,9 @@ static bool isPoison(DataFlowSolver &solver, Value value) {
     return false;
   const ConstantIntRanges &inferredRange =
       maybeInferredRange->getValue().getValue();
+
+  // Only generate poison if both the signed and unsigned ranges are guaranteed
+  // to be poison.
   return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
 }
 



More information about the Mlir-commits mailing list