[Mlir-commits] [mlir] [mlir] [DataFlow] Fix bug in int-range-analysis (PR #126708)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 02:12:28 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
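This patch fixes a bug in the integer range analysis:
When querying the lower bound, upper bound, and step of a loop to update the value range of a loop induction variable, the analysis must depend on the program point at the start of the block that owns the induction variable (`getProgramPointBefore(iv->getParentBlock())`). It previously depended on the program point after the loop operation (`getProgramPointAfter(op)`).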
---
Full diff: https://github.com/llvm/llvm-project/pull/126708.diff
2 Files Affected:
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+6-5)
- (modified) mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir (+30)
``````````diff
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9e9411e5ede12c8..722f4df18e9818c 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -152,7 +152,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
/// on a LoopLikeInterface return the lower/upper bound for that result if
/// possible.
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
- Type boundType, bool getUpper) {
+ Type boundType, Block *block, bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
if (loopBound.has_value()) {
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
@@ -160,7 +160,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return bound.getValue();
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(getProgramPointAfter(op), value);
+ getLatticeElementFor(getProgramPointBefore(block), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
@@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
+ Block *block = iv->getParentBlock();
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
std::optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
/*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+ getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
if (stepVal.isNegative()) {
std::swap(min, max);
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 1ec3441b1fde817..b98e8b07db5ce2b 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index {
return %2 : index
}
+// CHECK-LABEL: @propagate_from_block_to_iterarg
+func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.if %arg1 -> (index) {
+ %1 = scf.if %arg1 -> (index) {
+ scf.yield %arg0 : index
+ } else {
+ scf.yield %arg0 : index
+ }
+ scf.yield %1 : index
+ } else {
+ scf.yield %c1 : index
+ }
+ scf.for %arg2 = %c0 to %arg0 step %c1 {
+ scf.if %arg1 {
+ %1 = arith.subi %0, %c1 : index
+ %2 = arith.muli %0, %1 : index
+ %3 = arith.addi %2, %c1 : index
+ scf.for %arg3 = %c0 to %3 step %c1 {
+ %4 = arith.cmpi uge, %arg3, %c1 : index
+ // CHECK-NOT: scf.if %false
+ scf.if %4 {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/126708
More information about the Mlir-commits mailing list