[Mlir-commits] [mlir] [mlir][liveness] fix bugs in liveness analysis (PR #133416)

donald chen llvmlistbot at llvm.org
Fri Mar 28 03:57:23 PDT 2025


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/133416

>From d66973aaca6406b71384d533d898965b0cee2351 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Fri, 28 Mar 2025 18:22:36 +0800
Subject: [PATCH 1/2] [mlir][liveness] fix bugs in liveness analysis

This patch fixes the following bugs:
  - In SparseBackwardAnalysis, the setToExitState function should propagate
  changes if it modifies the lattice. Previously, this issue was masked
  because multi-block scenarios were not tested, and the traversal order of
  backward data flow analysis starts from the end of the program.
  - The logic in liveness analysis that decides whether a non-forwarded
  operand of a branch/region branch operation is live was incorrect, which
  could cause values that are actually live to be marked as not live.
---
 .../mlir/Analysis/DataFlow/SparseAnalysis.h   |  6 +-
 .../Analysis/DataFlow/LivenessAnalysis.cpp    | 81 +++++++++++++------
 .../DataFlow/test-liveness-analysis.mlir      | 37 ++++++++-
 3 files changed, 95 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b9cb549a0e438..1b2c679176107 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -413,10 +413,12 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   // Visit operands on call instructions that are not forwarded.
   virtual void visitCallOperand(OpOperand &operand) = 0;
 
-  /// Set the given lattice element(s) at control flow exit point(s).
+  /// Set the given lattice element(s) at control flow exit point(s) and
+  /// propagate the update if it changed.
   virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
 
-  /// Set the given lattice element(s) at control flow exit point(s).
+  /// Set the given lattice element(s) at control flow exit point(s) and
+  /// propagate the update if it changed.
   void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
 
   /// Get the lattice element for a value.
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 9fb4d9df2530d..07d2d400d6a7b 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -59,7 +59,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
 ///   (1.a) is an operand of an op with memory effects OR
 ///   (1.b) is a non-forwarded branch operand and its branch op could take the
 ///   control to a block that has an op with memory effects OR
-///   (1.c) is a non-forwarded call operand.
+///   (1.c) is a non-forwarded branch operand and its branch op could result
+///   in different result OR
+///   (1.d) is a non-forwarded call operand.
 ///
 /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
 /// computed in the absence of `A`. Thus, in this implementation, we say that
@@ -106,51 +108,76 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   // the forwarded branch operands or the non-branch operands. Thus they need
   // to be handled separately. This is where we handle them.
 
-  // This marks values of type (1.b) liveness as "live". A non-forwarded
+  // This marks values of type (1.b/1.c) liveness as "live". A non-forwarded
   // branch operand will be live if a block where its op could take the control
-  // has an op with memory effects.
+  // has an op with memory effects or could result in different results.
   // Populating such blocks in `blocks`.
+  bool mayLive = false;
   SmallVector<Block *, 4> blocks;
   if (isa<RegionBranchOpInterface>(op)) {
-    // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
-    // `scf.index_switch` op, its branch operand controls the flow into this
-    // op's regions.
-    for (Region &region : op->getRegions()) {
-      for (Block &block : region)
-        blocks.push_back(&block);
+    if (op->getNumResults() != 0) {
+      // This marks values of type 1.c liveness as may live, because the region
+      // branch operation has a return value, and the non-forwarded operand can
+      // determine the region to jump to, it can thereby control the result of
+      // the region branch operation.
+      // Therefore, we conservatively consider the non-forwarded operand of the
+      // region branch operation with result may live.
+      mayLive = true;
+    } else {
+      // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
+      // `scf.index_switch` op, its branch operand controls the flow into this
+      // op's regions.
+      for (Region &region : op->getRegions()) {
+        for (Block &block : region)
+          blocks.push_back(&block);
+      }
     }
   } else if (isa<BranchOpInterface>(op)) {
-    // When the op is a `BranchOpInterface`, like a `cf.cond_br` or a
-    // `cf.switch` op, its branch operand controls the flow into this op's
-    // successors.
-    blocks = op->getSuccessors();
+    // We cannot track all successor blocks of the branch operation (more
+    // specifically, the successors' successors). Additionally, different
+    // blocks might also lead to different block arguments, as described in
+    // 1.c. Therefore, we conservatively consider the non-forwarded operand of
+    // the branch operation to be possibly live.
+    mayLive = true;
   } else {
-    // When the op is a `RegionBranchTerminatorOpInterface`, like an
-    // `scf.condition` op or return-like, like an `scf.yield` op, its branch
-    // operand controls the flow into this op's parent's (which is a
-    // `RegionBranchOpInterface`'s) regions.
     Operation *parentOp = op->getParentOp();
     assert(isa<RegionBranchOpInterface>(parentOp) &&
            "expected parent op to implement `RegionBranchOpInterface`");
-    for (Region &region : parentOp->getRegions()) {
-      for (Block &block : region)
-        blocks.push_back(&block);
+    if (parentOp->getNumResults() != 0) {
+      // This marks values of type 1.c liveness as may live, because the region
+      // branch operation has a return value, and the non-forwarded operand can
+      // determine the region to jump to, it can thereby control the result of
+      // the region branch operation.
+      // Therefore, we conservatively consider the non-forwarded operand of the
+      // region branch operation with result may live.
+      mayLive = true;
+    } else {
+      // When the op is a `RegionBranchTerminatorOpInterface`, like an
+      // `scf.condition` op or return-like, like an `scf.yield` op, its branch
+      // operand controls the flow into this op's parent's (which is a
+      // `RegionBranchOpInterface`'s) regions.
+      for (Region &region : parentOp->getRegions()) {
+        for (Block &block : region)
+          blocks.push_back(&block);
+      }
     }
   }
-  bool foundMemoryEffectingOp = false;
   for (Block *block : blocks) {
-    if (foundMemoryEffectingOp)
+    if (mayLive)
       break;
     for (Operation &nestedOp : *block) {
       if (!isMemoryEffectFree(&nestedOp)) {
-        Liveness *operandLiveness = getLatticeElement(operand.get());
-        propagateIfChanged(operandLiveness, operandLiveness->markLive());
-        foundMemoryEffectingOp = true;
+        mayLive = true;
         break;
       }
     }
   }
 
+  if (mayLive) {
+    Liveness *operandLiveness = getLatticeElement(operand.get());
+    propagateIfChanged(operandLiveness, operandLiveness->markLive());
+  }
+
   // Now that we have checked for memory-effecting ops in the blocks of concern,
   // we will simply visit the op with this non-forwarded operand to potentially
   // mark it "live" due to type (1.a/3) liveness.
@@ -191,8 +218,12 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
 }
 
 void LivenessAnalysis::setToExitState(Liveness *lattice) {
+  if (lattice->isLive) {
+    return;
+  }
   // This marks values of type (2) liveness as "live".
   (void)lattice->markLive();
+  propagateIfChanged(lattice, ChangeResult::Change);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index b6aed1c0b054e..a89a0f4084e99 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -59,16 +59,49 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar
 
 // -----
 
+// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
+// op could result in different result"
+// CHECK-LABEL: test_tag: cond_br:
+// CHECK-NEXT:  operand #0: live
+// CHECK-NEXT:  operand #1: live
+// CHECK-NEXT:  operand #2: live
+func.func @test_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
+  cf.cond_br %arg2, ^bb1(%arg0 : tensor<f32>), ^bb2(%arg1 : tensor<f32>) {tag = "cond_br"}
+^bb1(%0 : tensor<f32>):
+  cf.br ^bb3(%0 : tensor<f32>)
+^bb2(%1 : tensor<f32>):
+  cf.br ^bb3(%1 : tensor<f32>)
+^bb3(%2 : tensor<f32>):
+  return %2 : tensor<f32>
+}
+
+// -----
+
+// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
+// op could result in different result"
+// CHECK-LABEL: test_tag: region_branch:
+// CHECK-NEXT:  operand #0: live
+func.func @test_region_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
+  %0 = scf.if %arg2 -> tensor<f32> {
+    scf.yield %arg0 : tensor<f32>
+  } else {
+    scf.yield %arg1 : tensor<f32>
+  } {tag="region_branch"}
+  return %0 : tensor<f32>
+}
+
+// -----
+
 func.func private @private(%arg0 : i32, %arg1 : i32) {
   func.return
 }
 
-// Positive test: Type (1.c) "is a non-forwarded call operand"
+// Positive test: Type (1.d) "is a non-forwarded call operand"
 // CHECK-LABEL: test_tag: call
 // CHECK-LABEL:  operand #0: not live
 // CHECK-LABEL:  operand #1: not live
 // CHECK-LABEL:  operand #2: live
-func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
+func.func @test_4_type_1.d(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
   test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> ()
   return
 }

>From e1f8c146cf3f1fbd4bd566fb095ee1cb179def30 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Fri, 28 Mar 2025 18:55:29 +0800
Subject: [PATCH 2/2] add more precise analysis for non-forwarded operand in
 branch operation

---
 .../Analysis/DataFlow/LivenessAnalysis.cpp    | 26 ++++++++++++++-----
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 07d2d400d6a7b..c12149a1a0242 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -60,7 +60,7 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
 ///   (1.b) is a non-forwarded branch operand and its branch op could take the
 ///   control to a block that has an op with memory effects OR
 ///   (1.c) is a non-forwarded branch operand and its branch op could result
-///   in different result OR
+///   in different live result OR
 ///   (1.d) is a non-forwarded call operand.
 ///
 /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
@@ -120,9 +120,15 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
       // branch operation has a return value, and the non-forwarded operand can
       // determine the region to jump to, it can thereby control the result of
       // the region branch operation.
-      // Therefore, we conservatively consider the non-forwarded operand of the
-      // region branch operation with result may live.
-      mayLive = true;
+      // Therefore, if any result value is live, we conservatively consider the
+      // non-forwarded operand of the region branch operation to be possibly
+      // live.
+      for (Value result : op->getResults()) {
+        if (getLatticeElement(result)->isLive) {
+          mayLive = true;
+          break;
+        }
+      }
     } else {
       // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
       // `scf.index_switch` op, its branch operand controls the flow into this
@@ -148,9 +154,15 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
       // branch operation has a return value, and the non-forwarded operand can
       // determine the region to jump to, it can thereby control the result of
       // the region branch operation.
-      // Therefore, we conservatively consider the non-forwarded operand of the
-      // region branch operation with result may live.
-      mayLive = true;
+      // Therefore, if any result value is live, we conservatively consider the
+      // non-forwarded operand of the region branch operation to be possibly
+      // live.
+      for (Value result : parentOp->getResults()) {
+        if (getLatticeElement(result)->isLive) {
+          mayLive = true;
+          break;
+        }
+      }
     } else {
       // When the op is a `RegionBranchTerminatorOpInterface`, like an
       // `scf.condition` op or return-like, like an `scf.yield` op, its branch



More information about the Mlir-commits mailing list