[Mlir-commits] [mlir] f15a6c9 - [mlir] [DataFlow] Fix bug in int-range-analysis (#126708)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 17:59:01 PST 2025
Author: donald chen
Date: 2025-02-12T09:58:56+08:00
New Revision: f15a6c99fa552f82dad46e6bf3c8ff958c8b6e7f
URL: https://github.com/llvm/llvm-project/commit/f15a6c99fa552f82dad46e6bf3c8ff958c8b6e7f
DIFF: https://github.com/llvm/llvm-project/commit/f15a6c99fa552f82dad46e6bf3c8ff958c8b6e7f.diff
LOG: [mlir] [DataFlow] Fix bug in int-range-analysis (#126708)
When querying the lower and upper bounds of a loop to update the
value range of a loop iteration variable, the program point to depend on
should be the block that owns the iteration variable (the point before
that block) rather than the loop operation.
Added:
Modified:
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9e9411e5ede12..722f4df18e981 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -152,7 +152,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
/// on a LoopLikeInterface return the lower/upper bound for that result if
/// possible.
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
- Type boundType, bool getUpper) {
+ Type boundType, Block *block, bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
if (loopBound.has_value()) {
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
@@ -160,7 +160,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return bound.getValue();
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(getProgramPointAfter(op), value);
+ getLatticeElementFor(getProgramPointBefore(block), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
@@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
+ Block *block = iv->getParentBlock();
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
std::optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
/*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+ getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
if (stepVal.isNegative()) {
std::swap(min, max);
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 1ec3441b1fde8..b98e8b07db5ce 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index {
return %2 : index
}
+// CHECK-LABEL: @propagate_from_block_to_iterarg
+func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.if %arg1 -> (index) {
+ %1 = scf.if %arg1 -> (index) {
+ scf.yield %arg0 : index
+ } else {
+ scf.yield %arg0 : index
+ }
+ scf.yield %1 : index
+ } else {
+ scf.yield %c1 : index
+ }
+ scf.for %arg2 = %c0 to %arg0 step %c1 {
+ scf.if %arg1 {
+ %1 = arith.subi %0, %c1 : index
+ %2 = arith.muli %0, %1 : index
+ %3 = arith.addi %2, %c1 : index
+ scf.for %arg3 = %c0 to %3 step %c1 {
+ %4 = arith.cmpi uge, %arg3, %c1 : index
+ // CHECK-NOT: scf.if %false
+ scf.if %4 {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+ }
+ return
+}
More information about the Mlir-commits
mailing list