[Mlir-commits] [mlir] [mlir][IntRange] Poison support in int-range analysis (PR #152932)

Ivan Butygin llvmlistbot at llvm.org
Sun Aug 10 14:27:56 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/152932

From f89476dd605ef37acae1a7f9f2fb3b8e3919117d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 10 Aug 2025 19:39:54 +0200
Subject: [PATCH] [mlir][IntRange] Poison support in int-range analysis

---
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  3 +-
 mlir/include/mlir/Dialect/UB/IR/UBOps.h       |  1 +
 mlir/include/mlir/Dialect/UB/IR/UBOps.td      |  6 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  | 15 +++
 .../Transforms/IntRangeOptimizations.cpp      | 25 ++++-
 mlir/lib/Dialect/UB/IR/UBOps.cpp              |  6 ++
 .../lib/Interfaces/InferIntRangeInterface.cpp | 93 +++++++++++++++---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 94 ++++++++++++++-----
 .../Dialect/Arith/int-range-interface.mlir    | 20 ++++
 mlir/test/Dialect/Arith/int-range-opts.mlir   | 12 ++-
 mlir/test/Dialect/UB/int-range-interface.mlir | 24 +++++
 .../Dialect/Vector/int-range-interface.mlir   | 17 ++++
 12 files changed, 277 insertions(+), 39 deletions(-)
 create mode 100644 mlir/test/Dialect/UB/int-range-interface.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   // Explicitly depend on "arith" because this pass could create operations in
   // `arith` out of thin air in some cases.
   let dependentDialects = [
-    "::mlir::arith::ArithDialect"
+    "::mlir::arith::ArithDialect",
+    "::mlir::ub::UBDialect"
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
 #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
 #define MLIR_DIALECT_UB_IR_UBOPS_TD
 
-include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 include "UBOpsInterfaces.td"
 
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
 // PoisonOp
 //===----------------------------------------------------------------------===//
 
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+  DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "Poisoned constant operation.";
   let description = [{
     The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..4e4f3725a69fd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
   /// The maximum value of an integer when it is interpreted as signed.
   const APInt &smax() const;
 
+  /// Get the bitwidth of the ranges.
+  unsigned getBitWidth() const;
+
   /// Return the bitwidth that should be used for integer ranges describing
   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
   /// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
+  /// Create a poisoned range, i.e. a range that represents no valid integer
+  /// values.
+  static ConstantIntRanges poison(unsigned bitwidth);
+
   /// Create a `ConstantIntRanges` with a constant value - that is, with the
   /// bounds [value, value] for both its signed interpretations.
   static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,14 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Returns true if the signed range is poisoned (smin > smax), i.e. no valid
+  /// signed value can be represented.
+  bool isSignedPoison() const;
+
+  /// Returns true if the unsigned range is poisoned (umin > umax), i.e. no
+  /// valid unsigned value can be represented.
+  bool isUnsignedPoison() const;
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static bool isPoison(DataFlowSolver &solver, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return false; // No initialized range is known, so do not claim poison.
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
   assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                        RewriterBase &rewriter, Value value) {
   if (value.use_empty())
     return failure();
+
+  if (isPoison(solver, value)) {
+    Value poison =
+        ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+    if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+      solver.eraseState(poison);
+    copyIntegerRange(solver, value, poison);
+    rewriter.replaceAllUsesWith(value, poison);
+    return success();
+  }
+
   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
   if (!maybeConstValue.has_value())
     return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
       return failure();
 
     auto needsReplacing = [&](Value v) {
-      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+      return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+             !v.use_empty();
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+  setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..4f6ab0306229f 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   type = getElementTypeOrSelf(type);
   if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
 }
 
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+  if (bitwidth == 0) { // Width 0 is never reported as poison.
+    auto zero = APInt::getZero(0);
+    return {zero, zero, zero, zero};
+  }
+
+  // Poison is represented by an empty range, i.e. one with min > max.
+  auto zero = APInt::getZero(bitwidth);
+  auto one = zero + 1;
+  auto onem = zero - 1; // All ones, i.e. -1 when read as signed.
+  // Unsigned [1, 0] and signed [0, -1] are empty at every width, including
+  // i1, whose valid unsigned range is [0, 1] and valid signed range [-1, 0].
+  return {one, zero, zero, onem};
+}
+
 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
   return {value, value, value, value};
 }
@@ -85,15 +102,37 @@ ConstantIntRanges
 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
-  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
-  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  APInt uminUnion;
+  APInt umaxUnion;
+  APInt sminUnion;
+  APInt smaxUnion;
+
+  if (isUnsignedPoison()) {
+    uminUnion = other.umin();
+    umaxUnion = other.umax();
+  } else if (other.isUnsignedPoison()) {
+    uminUnion = umin();
+    umaxUnion = umax();
+  } else {
+    uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+    umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminUnion = other.smin();
+    smaxUnion = other.smax();
+  } else if (other.isSignedPoison()) {
+    sminUnion = smin();
+    smaxUnion = smax();
+  } else {
+    sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+    smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
@@ -102,15 +141,37 @@ ConstantIntRanges
 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
-  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
-  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  APInt uminIntersect;
+  APInt umaxIntersect;
+  APInt sminIntersect;
+  APInt smaxIntersect;
+
+  if (isUnsignedPoison()) {
+    uminIntersect = umin();
+    umaxIntersect = umax();
+  } else if (other.isUnsignedPoison()) {
+    uminIntersect = other.umin();
+    umaxIntersect = other.umax();
+  } else {
+    uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+    umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminIntersect = smin();
+    smaxIntersect = smax();
+  } else if (other.isSignedPoison()) {
+    sminIntersect = other.smin();
+    smaxIntersect = other.smax();
+  } else {
+    sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+    smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
 }
@@ -124,6 +185,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
   return std::nullopt;
 }
 
+bool ConstantIntRanges::isSignedPoison() const {
+  return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+  return getBitWidth() > 0 && umin().ugt(umax());
+}
+
 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   os << "unsigned : [";
   range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..36841e2f2cc9a 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Poison each half (signed/unsigned) of `newRange` that is poison in any arg.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+                ArrayRef<ConstantIntRanges> argRanges) {
+  APInt umin = newRange.umin();
+  APInt umax = newRange.umax();
+  APInt smin = newRange.smin();
+  APInt smax = newRange.smax();
+
+  unsigned width = umin.getBitWidth();
+  for (const ConstantIntRanges &argRange : argRanges) {
+    if (argRange.isSignedPoison()) {
+      smin = APInt::getZero(width);
+      smax = smin - 1; // smin (0) > smax (-1): empty signed range.
+    }
+    if (argRange.isUnsignedPoison()) {
+      umax = APInt::getZero(width);
+      umin = umax + 1; // umin (1) > umax (0): empty unsigned range.
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -114,7 +137,7 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
     // Returing the 64-bit result preserves more information.
     return sixtyFour;
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
+  return propagatePoison(merged, argRanges);
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +146,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
+  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +196,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 //===----------------------------------------------------------------------===//
@@ -206,7 +229,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,7 +261,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +296,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -306,7 +329,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -351,10 +375,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
+    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                                    /*isSigned=*/true),
+                           {lhs, rhs});
   }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -395,7 +421,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return result;
+  return propagatePoison(result, argRanges);
 }
 
 ConstantIntRanges
@@ -425,6 +451,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +492,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -492,6 +524,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +535,8 @@ mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -510,6 +546,8 @@ mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -519,6 +557,8 @@ mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
 
   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
@@ -550,8 +590,10 @@ mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
   auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a & b;
   };
-  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(andi, {lhsZeros, lhsOnes},
+                                  {rhsZeros, rhsOnes},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 ConstantIntRanges
@@ -561,8 +603,9 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
   auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
     return a | b;
   };
-  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 /// Get bitmask of all bits which can change while iterating in
@@ -579,6 +622,9 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
   // Construct mask of varying bits for both ranges, xor values and then replace
   // masked bits with 0s and 1s to get min and max values respectively.
   ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(lhs.umin().getBitWidth());
+
   APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
   APInt res = lhs.umin() ^ rhs.umin();
   APInt min = res & ~mask;
@@ -621,7 +667,7 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 ConstantIntRanges
@@ -632,8 +678,10 @@ mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
   };
 
-  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/true);
+  return propagatePoison(minMaxBy(ashr, {lhs.smin(), lhs.smax()},
+                                  {rhs.umin(), rhs.umax()},
+                                  /*isSigned=*/true),
+                         argRanges);
 }
 
 ConstantIntRanges
@@ -643,8 +691,10 @@ mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
   auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
   };
-  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/false);
+  return propagatePoison(minMaxBy(lshr, {lhs.umin(), lhs.umax()},
+                                  {rhs.umin(), rhs.umax()},
+                                  /*isSigned=*/false),
+                         argRanges);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 2128d36f1a28e..35e755e98b1f9 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -654,6 +654,26 @@ func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 {
     func.return %5 : i1
 }
 
+// CHECK-LABEL: func @select_poison
+// CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
+func.func @select_poison(%arg0: i1) -> index {
+    %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+    %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index
+    %2 = arith.select %arg0, %0, %1 : index
+    %3 = test.reflect_bounds %2 : index
+    func.return %3 : index
+}
+
+// CHECK-LABEL: func @add_posion
+// CHECK: test.reflect_bounds {smax = -1 : index, smin = 0 : index, umax = 0 : index, umin = 1 : index}
+func.func @add_posion() -> index {
+    %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+    %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index
+    %2 = arith.addi %0, %1 : index
+    %3 = test.reflect_bounds %2 : index
+    func.return %3 : index
+}
+
 // CHECK-LABEL: func @if_union
 // CHECK: %[[true:.*]] = arith.constant true
 // CHECK: return %[[true]]
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..d613f38a55f01 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --int-range-optimizations --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @test
 //       CHECK:   %[[C:.*]] = arith.constant false
@@ -132,3 +132,13 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// CHECK-LABEL: func @create_poison_op
+// CHECK: %[[RES:.*]] = ub.poison : i32
+// CHECK: return %[[RES]]
+func.func @create_poison_op() -> i32 {
+  %val = test.with_bounds { umin = 1 : i32, umax = 0 : i32, smin = 1 : i32, smax = 0 : i32 } : i32
+  return %val : i32
+}
diff --git a/mlir/test/Dialect/UB/int-range-interface.mlir b/mlir/test/Dialect/UB/int-range-interface.mlir
new file mode 100644
index 0000000000000..69f4923ffe6c7
--- /dev/null
+++ b/mlir/test/Dialect/UB/int-range-interface.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// CHECK-LABEL: func @poison
+// CHECK: test.reflect_bounds {smax = -1 : si32, smin = 0 : si32, umax = 0 : ui32, umin = 1 : ui32}
+func.func @poison() -> i32 {
+    %0 = ub.poison : i32
+    %1 = test.reflect_bounds %0 : i32
+    func.return %1 : i32
+}
+
+// CHECK-LABEL: func @poison_i1
+// CHECK: test.reflect_bounds {smax = -1 : si1, smin = 0 : si1, umax = 0 : ui1, umin = 1 : ui1}
+func.func @poison_i1() -> i1 {
+    %0 = ub.poison : i1
+    %1 = test.reflect_bounds %0 : i1
+    func.return %1 : i1
+}
+
+// CHECK-LABEL: func @poison_non_int
+// Check it doesn't crash.
+func.func @poison_non_int() -> f32 {
+    %0 = ub.poison : f32
+    func.return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index b2f16bb3dac9c..3182ac6bf8b4b 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -116,3 +116,20 @@ func.func @vector_step() -> vector<8xindex> {
   %1 = test.reflect_bounds %0 : vector<8xindex>
   func.return %1 : vector<8xindex>
 }
+
+// CHECK-LABEL: func @poison_vector_insert
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 1 : index, umax = 4 : index, umin = 1 : index}
+func.func @poison_vector_insert() -> vector<4xindex> {
+  %0 = ub.poison : vector<4xindex>
+  %1 = test.with_bounds { umin = 1 : index, umax = 1 : index, smin = 1 : index, smax = 1 : index } : index
+  %2 = test.with_bounds { umin = 2 : index, umax = 2 : index, smin = 2 : index, smax = 2 : index } : index
+  %3 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 3 : index, smax = 3 : index } : index
+  %4 = test.with_bounds { umin = 4 : index, umax = 4 : index, smin = 4 : index, smax = 4 : index } : index
+  %5 = vector.insert %1, %0[0] : index into vector<4xindex>
+  %6 = vector.insert %2, %5[1] : index into vector<4xindex>
+  %7 = vector.insert %3, %6[2] : index into vector<4xindex>
+  %8 = vector.insert %4, %7[3] : index into vector<4xindex>
+
+  %9 = test.reflect_bounds %8 : vector<4xindex>
+  func.return %9 : vector<4xindex>
+}



More information about the Mlir-commits mailing list