[Mlir-commits] [mlir] a9ab845 - [MLIR][analysis] Fix error in the sparse backward dataflow analysis

Srishti Srivastava llvmlistbot at llvm.org
Fri Jul 28 23:31:29 PDT 2023


Author: Srishti Srivastava
Date: 2023-07-29T06:31:24Z
New Revision: a9ab845cb17e01ba83404d0fc82ac523d9f8dad0

URL: https://github.com/llvm/llvm-project/commit/a9ab845cb17e01ba83404d0fc82ac523d9f8dad0
DIFF: https://github.com/llvm/llvm-project/commit/a9ab845cb17e01ba83404d0fc82ac523d9f8dad0.diff

LOG: [MLIR][analysis] Fix error in the sparse backward dataflow analysis

Earlier, in the sparse backward dataflow analysis, data from the results
of an op implementing `RegionBranchOpInterface` was considered to flow
into the operands of every op that did not implement the
`RegionBranchTerminatorOpInterface` but was return-like and present
in a region of the former. It was thus also expected that the number of
results of the former be equal to the number of operands in the latter.

This understanding of dataflow is incorrect and thus this expectation is
also not justified. This commit fixes this incorrect understanding.

This commit ensures that these return-like ops are handled just like the
ops implementing the `RegionBranchTerminatorOpInterface`, which means
that, if this op has a region `A` whose successors are regions `B`, `C`,
and `D`, then data flows from the arguments (successor inputs) of `B`,
`C`, and `D` to the corresponding successor operands of this op.

This fix is also propagated to liveness analysis that earlier relied on
this incorrect implementation of the sparse backward dataflow analysis
framework and corrects some incorrect assumptions made in it.

Also cleaned up some unnecessary comments from the test file.

Issue: https://github.com/llvm/llvm-project/issues/64139.

Signed-off-by: Srishti Srivastava <srishtisrivastava.ai at gmail.com>

Reviewed By: jcai19, matthiaskramm, Mogball

Differential Revision: https://reviews.llvm.org/D156376

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
    mlir/test/Analysis/DataFlow/test-written-to.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index ba4dac3bfca2b7..c0cb09ddfd8c2a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -397,6 +397,14 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   void visitRegionSuccessors(RegionBranchOpInterface branch,
                              ArrayRef<AbstractSparseLattice *> operands);
 
+  /// Visit a terminator (an op implementing `RegionBranchTerminatorOpInterface`
+  /// or a return-like op) to compute the lattice values of its operands, given
+  /// its parent op `branch`. The lattice value of an operand is determined
+  /// based on the corresponding arguments in `terminator`'s region
+  /// successor(s).
+  void visitRegionSuccessorsFromTerminator(Operation *terminator,
+                                           RegionBranchOpInterface branch);
+
   /// Get the lattice element for a value, and also set up
   /// dependencies so that the analysis on the given ProgramPoint is re-invoked
   /// if the value changes.

diff  --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 45ad9c8497c20e..968b06572633e6 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -89,13 +89,15 @@ void LivenessAnalysis::visitOperation(Operation *op,
 
 void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   // We know (at the moment) and assume (for the future) that `operand` is a
-  // non-forwarded branch operand of an op of type `RegionBranchOpInterface`,
-  // `BranchOpInterface`, or `RegionBranchTerminatorOpInterface`.
+  // non-forwarded branch operand of a `RegionBranchOpInterface`,
+  // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
   Operation *op = operand.getOwner();
   assert((isa<RegionBranchOpInterface>(op) || isa<BranchOpInterface>(op) ||
-          isa<RegionBranchTerminatorOpInterface>(op)) &&
+          isa<RegionBranchTerminatorOpInterface>(op) ||
+          op->hasTrait<OpTrait::ReturnLike>()) &&
          "expected the op to be `RegionBranchOpInterface`, "
-         "`BranchOpInterface`, or `RegionBranchTerminatorOpInterface`");
+         "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or "
+         "return-like");
 
   // The lattices of the non-forwarded branch operands don't get updated like
   // the forwarded branch operands or the non-branch operands. Thus they need
@@ -120,10 +122,14 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
     // successors.
     blocks = op->getSuccessors();
   } else {
-    // When the op is a `RegionBranchTerminatorOpInterface`, like a
-    // `scf.condition` op, its branch operand controls the flow into this op's
-    // parent's (which is a `RegionBranchOpInterface`'s) regions.
-    for (Region &region : op->getParentOp()->getRegions()) {
+    // When the op is a `RegionBranchTerminatorOpInterface` (like an
+    // `scf.condition` op) or is return-like (like an `scf.yield` op), its
+    // branch operand controls the flow into the regions of this op's parent,
+    // which is a `RegionBranchOpInterface`.
+    Operation *parentOp = op->getParentOp();
+    assert(isa<RegionBranchOpInterface>(parentOp) &&
+           "expected parent op to implement `RegionBranchOpInterface`");
+    for (Region &region : parentOp->getRegions()) {
       for (Block &block : region)
         blocks.push_back(&block);
     }
@@ -155,10 +161,11 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   visitOperation(op, operandLiveness, resultsLiveness);
 
   // We also visit the parent op with the parent's results and this operand if
-  // `op` is a `RegionBranchTerminatorOpInterface` because its non-forwarded
-  // operand depends on not only its memory effects/results but also on those of
-  // its parent's.
-  if (!isa<RegionBranchTerminatorOpInterface>(op))
+  // `op` is a `RegionBranchTerminatorOpInterface` or return-like because its
+  // non-forwarded operand depends not only on its memory effects/results but
+  // also on those of its parent.
+  if (!isa<RegionBranchTerminatorOpInterface>(op) &&
+      !op->hasTrait<OpTrait::ReturnLike>())
     return;
   Operation *parentOp = op->getParentOp();
   SmallVector<const Liveness *, 4> parentResultsLiveness;

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index c8242f39492877..3007b3826e439f 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -429,55 +429,25 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
     }
   }
 
-  // The block arguments of the branched to region flow back into the
-  // operands of the yield operation.
-  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+  // When the region of an op implementing `RegionBranchOpInterface` has a
+  // terminator implementing `RegionBranchTerminatorOpInterface` or a
+  // return-like terminator, the region's successors' arguments flow back into
+  // the "successor operands" of this terminator.
+  //
+  // A successor operand with respect to an op implementing
+  // `RegionBranchOpInterface` is an operand that is forwarded to a region
+  // successor's input. There are two types of successor operands: the operands
+  // of this op itself and the operands of the terminators of the regions of
+  // this op.
+  if (isa<RegionBranchTerminatorOpInterface>(op) ||
+      op->hasTrait<OpTrait::ReturnLike>()) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
-      SmallVector<RegionSuccessor> successors;
-      SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
-      branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
-                                 operands, successors);
-      // All operands not forwarded to any successor. This set can be
-      // non-contiguous in the presence of multiple successors.
-      BitVector unaccounted(op->getNumOperands(), true);
-
-      for (const RegionSuccessor &successor : successors) {
-        ValueRange inputs = successor.getSuccessorInputs();
-        Region *region = successor.getSuccessor();
-        OperandRange operands =
-            region ? terminator.getSuccessorOperands(region->getRegionNumber())
-                   : terminator.getSuccessorOperands({});
-        MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
-        for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) {
-          meet(getLatticeElement(opoperand.get()),
-               *getLatticeElementFor(op, input));
-          unaccounted.reset(
-              const_cast<OpOperand &>(opoperand).getOperandNumber());
-        }
-      }
-      // Visit operands of the branch op not forwarded to the next region.
-      // (Like e.g. the boolean of `scf.conditional`)
-      for (int index : unaccounted.set_bits()) {
-        visitBranchOperand(op->getOpOperand(index));
-      }
+      visitRegionSuccessorsFromTerminator(op, branch);
       return;
     }
   }
 
-  // yield-like ops usually don't implement `RegionBranchTerminatorOpInterface`,
-  // since they behave like a return in the sense that they forward to the
-  // results of some other (here: the parent) op.
   if (op->hasTrait<OpTrait::ReturnLike>()) {
-    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
-      OperandRange operands = op->getOperands();
-      ResultRange results = op->getParentOp()->getResults();
-      assert(results.size() == operands.size() &&
-             "Can't derive arg mapping for yield-like op.");
-      for (auto [operand, result] : llvm::zip(operands, results))
-        meet(getLatticeElement(operand), *getLatticeElementFor(op, result));
-      return;
-    }
-
     // Going backwards, the operands of the return are derived from the
     // results of all CallOps calling this CallableOp.
     if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
@@ -535,6 +505,46 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
   }
 }
 
+void AbstractSparseBackwardDataFlowAnalysis::
+    visitRegionSuccessorsFromTerminator(Operation *terminator,
+                                        RegionBranchOpInterface branch) {
+  assert((isa<RegionBranchTerminatorOpInterface>(terminator) ||
+          terminator->hasTrait<OpTrait::ReturnLike>()) &&
+         "expected a `RegionBranchTerminatorOpInterface` op or a "
+         "return-like op");
+  assert(terminator->getParentOp() == branch.getOperation() &&
+         "expected `branch` to be the parent op of `terminator`");
+
+  SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
+                                           nullptr);
+  SmallVector<RegionSuccessor> successors;
+  branch.getSuccessorRegions(terminator->getParentRegion()->getRegionNumber(),
+                             operandAttributes, successors);
+  // All operands not forwarded to any successor. This set can be
+  // non-contiguous in the presence of multiple successors.
+  BitVector unaccounted(terminator->getNumOperands(), true);
+
+  // Meet each forwarded operand with the successor input it is paired with.
+  for (const RegionSuccessor &successor : successors) {
+    ValueRange inputs = successor.getSuccessorInputs();
+    Region *region = successor.getSuccessor();
+    OperandRange operands =
+        region ? *getRegionBranchSuccessorOperands(terminator,
+                                                   region->getRegionNumber())
+               : *getRegionBranchSuccessorOperands(terminator, {});
+    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
+    for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
+      meet(getLatticeElement(opOperand.get()),
+           *getLatticeElementFor(terminator, input));
+      unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
+    }
+  }
+  // Visit operands of the terminator that are not forwarded to any successor
+  // (e.g., the boolean condition operand of `scf.condition`).
+  for (int index : unaccounted.set_bits())
+    visitBranchOperand(terminator->getOpOperand(index));
+}
+
 const AbstractSparseLattice *
 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                              Value value) {

diff  --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index 9a92e3ba94310e..a040fb3961a9d3 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -16,8 +16,6 @@ func.func @test_1_type_1.a(%arg0: memref<i32>) {
 // where its op could take the control has an op with memory effects"
 // %arg2 is live because it can make the control go into a block with a memory
 // effecting op.
-// Note that if `visitBranchOperand()` was left empty, it would have been
-// incorrectly marked as "not live".
 // CHECK-LABEL: test_tag: br:
 // CHECK-NEXT:  operand #0: live
 // CHECK-NEXT:  operand #1: live
@@ -41,8 +39,6 @@ func.func @test_2_RegionBranchOpInterface_type_1.b(%arg0: memref<i32>, %arg1: me
 // where its op could take the control has an op with memory effects"
 // %arg0 is live because it can make the control go into a block with a memory
 // effecting op.
-// Note that if `visitBranchOperand()` was left empty, it would have been
-// incorrectly marked as "not live".
 // CHECK-LABEL: test_tag: flag:
 // CHECK-NEXT:  operand #0: live
 func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %arg2: memref<i32>) {
@@ -77,26 +73,35 @@ func.func @test_4_type_2() -> (f32){
 // Positive test: Type (3) "is used to compute a value of type (1) or (2)"
 // %arg1 is live because the scf.while has a live result and %arg1 is a
 // non-forwarded branch operand.
-// Note that if `visitBranchOperand()` was left empty, it would have been
-// incorrectly marked as "not live".
 // %arg2 is live because it is forwarded to the live result of the scf.while
 // op.
-// Negative test: %arg3 is not live even though %arg1 and %arg2 are live
-// because it is neither a non-forwarded branch operand nor a forwarded
-// operand that forwards to a live value. It actually is a forwarded operand
-// that forwards to a non-live value.
+// %arg5 is live because it is forwarded to %arg8 which is live.
+// %arg8 is live because it is forwarded to %arg4 which is live as it writes
+// to memory.
+// Negative test:
+// %arg3 is not live even though %arg1, %arg2, and %arg5 are live because it
+// is neither a non-forwarded branch operand nor a forwarded operand that
+// forwards to a live value. It actually is a forwarded operand that forwards
+// to non-live values %0#1 and %arg7.
 // CHECK-LABEL: test_tag: condition:
 // CHECK-NEXT:  operand #0: live
 // CHECK-NEXT:  operand #1: live
 // CHECK-NEXT:  operand #2: not live
+// CHECK-NEXT:  operand #3: live
+// CHECK-LABEL: test_tag: add:
+// CHECK-NEXT:  operand #0: live
 func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref<i32>, %arg1: i1) -> (i32) {
   %c0_i32 = arith.constant 0 : i32
   %c1_i32 = arith.constant 1 : i32
-  %0:2 = scf.while (%arg2 = %c0_i32, %arg3 = %c1_i32) : (i32, i32) -> (i32, i32) {
-    scf.condition(%arg1) {tag = "condition"} %arg2, %arg3 : i32, i32
+  %c2_i32 = arith.constant 2 : i32
+  %0:3 = scf.while (%arg2 = %c0_i32, %arg3 = %c1_i32, %arg4 = %c2_i32, %arg5 = %c2_i32) : (i32, i32, i32, i32) -> (i32, i32, i32) {
+    memref.store %arg4, %arg0[] : memref<i32>
+    scf.condition(%arg1) {tag = "condition"} %arg2, %arg3, %arg5 : i32, i32, i32
   } do {
-  ^bb0(%arg2: i32, %arg3: i32):
-    scf.yield %arg2, %arg3 : i32, i32
+  ^bb0(%arg6: i32, %arg7: i32, %arg8: i32):
+    %1 = arith.addi %arg8, %arg8 {tag = "add"} : i32
+    %c3_i32 = arith.constant 3 : i32
+    scf.yield %arg6, %arg7, %arg8, %c3_i32 : i32, i32, i32, i32
   }
   return %0#0 : i32
 }
@@ -112,12 +117,10 @@ func.func private @private0(%0 : i32) -> i32 {
 // zero, ten, and one are live because they are used to decide the number of
 // times the `for` loop executes, which in turn decides the value stored in
 // memory.
-// Note that if `visitBranchOperand()` was left empty, they would have been
-// incorrectly marked as "not live".
 // in_private0 and x are also live because they decide the value stored in
 // memory.
-// Negative test: y is not live even though the non-forwarded branch operand
-// and x are live.
+// Negative test:
+// y is not live even though the non-forwarded branch operand and x are live.
 // CHECK-LABEL: test_tag: in_private0:
 // CHECK-NEXT:  operand #0: live
 // CHECK-NEXT:  operand #1: live

diff  --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
index 11a9f0316aecd7..1ff92f56a1a80c 100644
--- a/mlir/test/Analysis/DataFlow/test-written-to.mlir
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -168,9 +168,10 @@ func.func @test_callchain(%m0: memref<f32>, %arg: f32) {
 // CHECK-LABEL: test_tag: zero
 // CHECK: result #0: [c]
 // CHECK-LABEL: test_tag: init
-// CHECK: result #0: [a b]
+// CHECK: result #0: [a b c]
 // CHECK-LABEL: test_tag: condition
 // CHECK: operand #0: [brancharg0]
+// CHECK: operand #2: [a b c]
 func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
   %zero = arith.constant {tag = "zero"} 0 : i32
   %init2 = arith.addi %init, %init {tag = "init"} : i32
@@ -181,7 +182,7 @@ func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
    ^bb0(%arg1: i32, %arg2: i32):
     memref.store %arg1, %m0[] {tag_name = "c"} : memref<i32>
     %res = arith.addi %arg2, %arg2 : i32
-    scf.yield %arg1, %res: i32, i32
+    scf.yield %res, %res: i32, i32
   }
   memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
   return
@@ -189,6 +190,32 @@ func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
 
 // -----
 
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: []
+// CHECK-LABEL: test_tag: one
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: condition
+// CHECK: operand #0: [brancharg0]
+//
+// The important thing to note in this test is that the sparse backward dataflow
+// analysis framework also works on complex region branch ops like this one,
+// where the number of operands of the `scf.yield` op doesn't match the number
+// of results of the parent op.
+func.func @test_complex_while(%m0: memref<i32>, %cond: i1) {
+  %zero = arith.constant {tag = "zero"} 0 : i32
+  %one = arith.constant {tag = "one"} 1 : i32
+  %0 = scf.while (%arg1 = %zero, %arg2 = %one) : (i32, i32) -> (i32) {
+    scf.condition(%cond) {tag = "condition"} %arg2 : i32
+  } do {
+   ^bb0(%arg1: i32):
+    scf.yield %arg1, %arg1: i32, i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: test_tag: zero
 // CHECK: result #0: [brancharg0]
 // CHECK-LABEL: test_tag: ten


        


More information about the Mlir-commits mailing list