[Mlir-commits] [mlir] [mlir] [DataFlow] Fix bug in int-range-analysis (PR #126708)

donald chen llvmlistbot at llvm.org
Tue Feb 11 02:14:17 PST 2025


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/126708

>From 4755da6907951c7034310910227a817d3d607fec Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Tue, 11 Feb 2025 10:07:26 +0000
Subject: [PATCH] [mlir] [DataFlow] Fix bug in int-range-analysis

When querying the lower and upper bounds of a loop to update the value range
of a loop induction variable, the program point to depend on should be the
block that owns the induction variable, rather than the loop operation itself.
---
 .../DataFlow/IntegerRangeAnalysis.cpp         | 11 +++----
 .../infer-int-range-test-ops.mlir             | 30 +++++++++++++++++++
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9e9411e5ede12c8..722f4df18e9818c 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -152,7 +152,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
   /// on a LoopLikeInterface return the lower/upper bound for that result if
   /// possible.
   auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
-                                  Type boundType, bool getUpper) {
+                                  Type boundType, Block *block, bool getUpper) {
     unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
     if (loopBound.has_value()) {
       if (auto attr = dyn_cast<Attribute>(*loopBound)) {
@@ -160,7 +160,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
           return bound.getValue();
       } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
         const IntegerValueRangeLattice *lattice =
-            getLatticeElementFor(getProgramPointAfter(op), value);
+            getLatticeElementFor(getProgramPointBefore(block), value);
         if (lattice != nullptr && !lattice->getValue().isUninitialized())
           return getUpper ? lattice->getValue().getValue().smax()
                           : lattice->getValue().getValue().smin();
@@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
           op, successor, argLattices, firstIndex);
     }
+    Block *block = iv->getParentBlock();
     std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
     std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
     std::optional<OpFoldResult> step = loop.getSingleStep();
-    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
                                      /*getUpper=*/false);
-    APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+    APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
                                      /*getUpper=*/true);
     // Assume positivity for uniscoverable steps by way of getUpper = true.
     APInt stepVal =
-        getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+        getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
 
     if (stepVal.isNegative()) {
       std::swap(min, max);
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 1ec3441b1fde817..b98e8b07db5ce2b 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index {
   return %2 : index
 }
 
+// CHECK-LABEL: @propagate_from_block_to_iterarg
+func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.if %arg1 -> (index) {
+    %1 = scf.if %arg1 -> (index) {
+      scf.yield %arg0 : index
+    } else {
+      scf.yield %arg0 : index
+    }
+    scf.yield %1 : index
+  } else {
+    scf.yield %c1 : index
+  }
+  scf.for %arg2 = %c0 to %arg0 step %c1 {
+    scf.if %arg1 {
+      %1 = arith.subi %0, %c1 : index
+      %2 = arith.muli %0, %1 : index
+      %3 = arith.addi %2, %c1 : index
+      scf.for %arg3 = %c0 to %3 step %c1 {
+        %4 = arith.cmpi uge, %arg3, %c1 : index
+        // CHECK-NOT: scf.if %false
+        scf.if %4 {
+          "test.foo"() : () -> ()
+        }
+      }
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list