[Mlir-commits] [mlir] [mlir][IntRange] Poison support in int-range analysis (PR #152932)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 16 04:22:18 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

* Represent a poisoned range as an empty range.
* Poison is propagated through the DAG until a `select` op or a CFG merge is reached, where we take the non-poison alternative (`freeze(poison)` should also convert poison to the full range, but we don't model `freeze` in MLIR yet). This enables more `select` optimization opportunities.
* Vector value ranges are currently modeled as the union of the individual element ranges, so this also supports the common vector pattern where we create a poison vector and insert elements one-by-one (previously the result was assumed to be the full range).
* Update `intRangeOptimizations` to produce a poison op if a value range was inferred to be poison.

---

Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152932.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+2-1) 
- (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+1) 
- (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.td (+4-2) 
- (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+17) 
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+24-1) 
- (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+6) 
- (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+89-12) 
- (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+74-23) 
- (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+20) 
- (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+11-1) 
- (added) mlir/test/Dialect/UB/int-range-interface.mlir (+24) 
- (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   // Explicitly depend on "arith" because this pass could create operations in
   // `arith` out of thin air in some cases.
   let dependentDialects = [
-    "::mlir::arith::ArithDialect"
+    "::mlir::arith::ArithDialect",
+    "::mlir::ub::UBDialect"
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
 #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
 #define MLIR_DIALECT_UB_IR_UBOPS_TD
 
-include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 include "UBOpsInterfaces.td"
 
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
 // PoisonOp
 //===----------------------------------------------------------------------===//
 
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+  DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "Poisoned constant operation.";
   let description = [{
     The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
   /// The maximum value of an integer when it is interpreted as signed.
   const APInt &smax() const;
 
+  /// Get the bitwidth of the ranges.
+  unsigned getBitWidth() const;
+
   /// Return the bitwidth that should be used for integer ranges describing
   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
   /// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
+  /// Create a poisoned range, i.e. a range that represents no valid integer
+  /// values.
+  static ConstantIntRanges poison(unsigned bitwidth);
+
   /// Create a `ConstantIntRanges` with a constant value - that is, with the
   /// bounds [value, value] for both its signed interpretations.
   static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,16 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Returns true if signed range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isSignedPoison() const;
+
+  /// Returns true if unsigned range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isUnsignedPoison() const;
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static bool isPoison(DataFlowSolver &solver, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return false;
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
   assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                        RewriterBase &rewriter, Value value) {
   if (value.use_empty())
     return failure();
+
+  if (isPoison(solver, value)) {
+    Value poison =
+        ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+    if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+      solver.eraseState(poison);
+    copyIntegerRange(solver, value, poison);
+    rewriter.replaceAllUsesWith(value, poison);
+    return success();
+  }
+
   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
   if (!maybeConstValue.has_value())
     return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
       return failure();
 
     auto needsReplacing = [&](Value v) {
-      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+      return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+             !v.use_empty();
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+  setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   type = getElementTypeOrSelf(type);
   if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
 }
 
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+  if (bitwidth == 0) {
+    auto zero = APInt::getZero(0);
+    return {zero, zero, zero, zero};
+  }
+
+  // Poison is represented by an empty range.
+  auto zero = APInt::getZero(bitwidth);
+  auto one = zero + 1;
+  auto onem = zero - 1;
+  // For i1 the valid unsigned range is [0, 1] and the valid signed range
+  // is [-1, 0].
+  return {one, zero, zero, onem};
+}
+
 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
   return {value, value, value, value};
 }
@@ -85,15 +102,44 @@ ConstantIntRanges
 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
-  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
-  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  APInt uminUnion;
+  APInt umaxUnion;
+  APInt sminUnion;
+  APInt smaxUnion;
+
+  // Union of poisoned range with any other range is the other range.
+  // Union is used when we need to merge ranges from multiple indepdenent
+  // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+  // value (using it in side-effecting operation) will cause the immediate UB.
+  // Well-formed programs should never observe the immediate UB so we assume
+  // result is either unused or only used in circumstances when it received the
+  // non-poisoned argument.
+  if (isUnsignedPoison()) {
+    uminUnion = other.umin();
+    umaxUnion = other.umax();
+  } else if (other.isUnsignedPoison()) {
+    uminUnion = umin();
+    umaxUnion = umax();
+  } else {
+    uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+    umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminUnion = other.smin();
+    smaxUnion = other.smax();
+  } else if (other.isSignedPoison()) {
+    sminUnion = smin();
+    smaxUnion = smax();
+  } else {
+    sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+    smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
@@ -102,15 +148,38 @@ ConstantIntRanges
 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
-  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
-  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  APInt uminIntersect;
+  APInt umaxIntersect;
+  APInt sminIntersect;
+  APInt smaxIntersect;
+
+  // Intersection of poisoned range with any other range is poisoned.
+  if (isUnsignedPoison()) {
+    uminIntersect = umin();
+    umaxIntersect = umax();
+  } else if (other.isUnsignedPoison()) {
+    uminIntersect = other.umin();
+    umaxIntersect = other.umax();
+  } else {
+    uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+    umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminIntersect = smin();
+    smaxIntersect = smax();
+  } else if (other.isSignedPoison()) {
+    sminIntersect = other.smin();
+    smaxIntersect = other.smax();
+  } else {
+    sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+    smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
 }
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
   return std::nullopt;
 }
 
+bool ConstantIntRanges::isSignedPoison() const {
+  return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+  return getBitWidth() > 0 && umin().ugt(umax());
+}
+
 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   os << "unsigned : [";
   range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..cc2338c684f58 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// If any of the arguments are poison, return poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+                ArrayRef<ConstantIntRanges> argRanges) {
+  APInt umin = newRange.umin();
+  APInt umax = newRange.umax();
+  APInt smin = newRange.smin();
+  APInt smax = newRange.smax();
+
+  unsigned width = umin.getBitWidth();
+  for (const ConstantIntRanges &argRange : argRanges) {
+    if (argRange.isSignedPoison()) {
+      smin = APInt::getZero(width);
+      smax = smin - 1;
+    }
+    if (argRange.isUnsignedPoison()) {
+      umax = APInt::getZero(width);
+      umin = umax + 1;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -112,9 +135,10 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
   }
   if (truncEqual)
     // Returing the 64-bit result preserves more information.
-    return sixtyFour;
+    return propagatePoison(sixtyFour, argRanges);
+
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
+  return propagatePoison(merged, argRanges);
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +147,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
+  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +197,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 //===----------------------------------------------------------------------===//
@@ -206,7 +230,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,7 +262,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +297,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -306,7 +330,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -351,10 +376,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
+    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                                    /*isSigned=*/true),
+                           {lhs, rhs});
   }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -395,7 +422,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return result;
+  return propagatePoison(result, argRanges);
 }
 
 ConstantIntRanges
@@ -425,6 +452,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +493,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -492,6 +525,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +536,...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/152932


More information about the Mlir-commits mailing list