[Mlir-commits] [mlir] 25ad2ee - [mlir][IntegerRangeAnalysis] Don't unsoundly update constant lattice (#193546)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 23 05:29:02 PDT 2026


Author: Krzysztof Drewniak
Date: 2026-04-23T14:28:58+02:00
New Revision: 25ad2ee86da1642431b98c3e3c45942c80ff7dbf

URL: https://github.com/llvm/llvm-project/commit/25ad2ee86da1642431b98c3e3c45942c80ff7dbf
DIFF: https://github.com/llvm/llvm-project/commit/25ad2ee86da1642431b98c3e3c45942c80ff7dbf.diff

LOG: [mlir][IntegerRangeAnalysis] Don't unsoundly update constant lattice (#193546)

Fixes #119045

Integer range analysis tried to be clever and update the constant value
lattice whenever it inferred that a value was a constant. However, this
caused correctness issues. The integer range analysis can move a value
from "constant" to "not a constant" once more control-flow edges are
analyzed. Meanwhile, dead code analysis reads the constant value lattice
to skip branches, including the very branches whose analysis might prove
that the branch condition isn't actually a constant.

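Unlike my long-ago attempt at a fix, the solution here is simply to stop
trying to be clever: IntegerValueRangeLattice no longer overrides
onUpdate() to write inferred constants into the constant value lattice.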

This change required loading the sparse constant propagation analysis
(SparseConstantPropagation) in -arith-unsigned-when-equivalent to keep
tests working. That analysis should have been loaded there anyway, because
the dead code analysis needs it to work correctly.

No AI tools used.

Added: 
    mlir/test/Dialect/Arith/int-range-opts-bug-119045.mlir

Modified: 
    mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
    mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
    mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 9820a91291fdb..5b6ae9bf84265 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -26,15 +26,9 @@ class RewriterBase;
 namespace dataflow {
 
 /// This lattice element represents the integer value range of an SSA value.
-/// When this lattice is updated, it automatically updates the constant value
-/// of the SSA value (if the range can be narrowed to one).
 class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
 public:
   using Lattice::Lattice;
-
-  /// If the range can be narrowed to an integer constant, update the constant
-  /// value of the SSA value.
-  void onUpdate(DataFlowSolver *solver) const override;
 };
 
 /// Integer range analysis determines the integer value range of SSA values

diff  --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 818450e2bc696..b29fc28131806 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -13,7 +13,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -59,37 +58,6 @@ LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
 }
 } // namespace mlir::dataflow
 
-void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
-  Lattice::onUpdate(solver);
-
-  // If the integer range can be narrowed to a constant, update the constant
-  // value of the SSA value.
-  std::optional<APInt> constant = getValue().getValue().getConstantValue();
-  auto value = cast<Value>(anchor);
-  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
-  if (!constant)
-    return solver->propagateIfChanged(
-        cv, cv->join(ConstantValue::getUnknownConstant()));
-
-  Dialect *dialect;
-  if (auto *parent = value.getDefiningOp())
-    dialect = parent->getDialect();
-  else
-    dialect = value.getParentBlock()->getParentOp()->getDialect();
-
-  Attribute cstAttr;
-  if (isa<IntegerType, IndexType>(value.getType())) {
-    cstAttr = IntegerAttr::get(value.getType(), *constant);
-  } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
-    cstAttr = SplatElementsAttr::get(shapedTy, *constant);
-  } else {
-    llvm::report_fatal_error(
-        Twine("FIXME: Don't know how to create a constant for this type: ") +
-        mlir::debugString(value.getType()));
-  }
-  solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
-}
-
 LogicalResult IntegerRangeAnalysis::visitOperation(
     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
     ArrayRef<IntegerValueRangeLattice *> results) {

diff  --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index bd1eac16070eb..c9eaa66d6ea49 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -8,6 +8,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
@@ -125,6 +126,7 @@ struct ArithUnsignedWhenEquivalentPass
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
     DataFlowSolver solver;
+    solver.load<SparseConstantPropagation>();
     solver.load<DeadCodeAnalysis>();
     solver.load<IntegerRangeAnalysis>();
     if (failed(solver.initializeAndRun(op)))

diff  --git a/mlir/test/Dialect/Arith/int-range-opts-bug-119045.mlir b/mlir/test/Dialect/Arith/int-range-opts-bug-119045.mlir
new file mode 100644
index 0000000000000..a2cf72a00e019
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-opts-bug-119045.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -int-range-optimizations %s
+
+// Note: I wish I had a simpler example than this, but getting rid of a
+// bunch of the arithmetic made the issue go away.
+// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
+// CHECK-NOT: arith.constant true
+// CHECK-COUNT-4: cf.cond_br
+// CHECK: return
+func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
+  %cst = arith.constant dense<false> : vector<1xi1>
+  %c1 = arith.constant 1 : index
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+  %cst_1 = arith.constant 0.000000e+00 : f16
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %thread_id_x = gpu.thread_id  x upper_bound 64
+  %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
+  %8 = arith.divui %6, %c16 : index
+  %9 = arith.muli %8, %c16 : index
+  cf.br ^bb1(%c0 : index)
+^bb1(%12: index):  // 2 preds: ^bb0, ^bb7
+  %13 = arith.cmpi slt, %12, %9 : index
+  cf.cond_br %13, ^bb2, ^bb8
+^bb2:  // pred: ^bb1
+  %14 = arith.subi %9, %12 : index
+  %15 = arith.minsi %14, %c64 : index
+  %16 = arith.subi %15, %thread_id_x : index
+  %17 = vector.constant_mask [1] : vector<1xi1>
+  %18 = arith.cmpi sgt, %16, %c0 : index
+  %19 = arith.select %18, %17, %cst : vector<1xi1>
+  %20 = vector.extract %19[0] : i1 from vector<1xi1>
+  %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
+  %22 = arith.addi %12, %thread_id_x : index
+  cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
+^bb3(%23: index, %24: vector<1xf16>):  // 2 preds: ^bb2, ^bb6
+  %25 = arith.cmpi slt, %23, %c1 : index
+  cf.cond_br %25, ^bb4, ^bb7
+^bb4:  // pred: ^bb3
+  %26 = vector.extract %21[%23] : i1 from vector<1xi1>
+  cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
+^bb5:  // pred: ^bb4
+  %27 = arith.addi %22, %23 : index
+  %28 = memref.load %mem[%27] : memref<?xf16>
+  %29 = vector.insert %28, %24[%23] : f16 into vector<1xf16>
+  cf.br ^bb6(%29 : vector<1xf16>)
+^bb6(%30: vector<1xf16>):  // 2 preds: ^bb4, ^bb5
+  %31 = arith.addi %23, %c1 : index
+  cf.br ^bb3(%31, %30 : index, vector<1xf16>)
+^bb7:  // pred: ^bb3
+  %37 = arith.addi %12, %c64 : index
+  cf.br ^bb1(%37 : index)
+^bb8:  // pred: ^bb1
+  %70 = arith.cmpi eq, %thread_id_x, %c0 : index
+  cf.cond_br %70, ^bb9, ^bb10
+^bb9:  // pred: ^bb8
+  memref.store %cst_1, %mem[%c0] : memref<?xf16>
+  cf.br ^bb10
+^bb10:  // 2 preds: ^bb8, ^bb9
+  return
+}


        


More information about the Mlir-commits mailing list