[Mlir-commits] [mlir] [mlir][dataflow] Propagate errors from `visitOperation` (PR #105448)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 20 16:26:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

The base `DataFlowAnalysis::visit` returns `LogicalResult`, but the `visitOperation` hooks of the Sparse/Dense Forward/Backward wrappers don't.

Sometimes the solver needs to be aborted early when an unrecoverable condition is detected inside an analysis.

Update `visitOperation` to return `LogicalResult` and propagate it to `solver.initializeAndRun()`. Only `visitOperation` is updated for now; it's possible to update other hooks like `visitNonControlFlowArguments`, but it's not needed immediately, so let's keep this PR small.

Hijacked the `UnderlyingValueAnalysis` test analysis to test it.

---

Patch is 43.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105448.diff


15 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h (+4-3) 
- (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+22-20) 
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+4-3) 
- (modified) mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h (+2-2) 
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+14-12) 
- (modified) mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp (+6-5) 
- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+32-19) 
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+6-3) 
- (modified) mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp (+6-5) 
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+41-27) 
- (added) mlir/test/Analysis/DataFlow/test-last-modified-error.mlir (+8) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (+22-14) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+9-3) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp (+21-14) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+7-6) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 1bf991dc193874..d2d4ff9960ea36 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -101,9 +101,10 @@ class SparseConstantPropagation
 public:
   using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
 
-  void visitOperation(Operation *op,
-                      ArrayRef<const Lattice<ConstantValue> *> operands,
-                      ArrayRef<Lattice<ConstantValue> *> results) override;
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const Lattice<ConstantValue> *> operands,
+                 ArrayRef<Lattice<ConstantValue> *> results) override;
 
   void setToEntryState(Lattice<ConstantValue> *lattice) override;
 };
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 088b6cd7d698fc..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -87,9 +87,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
 protected:
   /// Propagate the dense lattice before the execution of an operation to the
   /// lattice after its execution.
-  virtual void visitOperationImpl(Operation *op,
-                                  const AbstractDenseLattice &before,
-                                  AbstractDenseLattice *after) = 0;
+  virtual LogicalResult visitOperationImpl(Operation *op,
+                                           const AbstractDenseLattice &before,
+                                           AbstractDenseLattice *after) = 0;
 
   /// Get the dense lattice after the execution of the given program point.
   virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
@@ -114,7 +114,7 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// operation, then the state after the execution of the operation is set by
   /// control-flow or the callgraph. Otherwise, this function invokes the
   /// operation transfer function.
-  virtual void processOperation(Operation *op);
+  virtual LogicalResult processOperation(Operation *op);
 
   /// Propagate the dense lattice forward along the control flow edge from
   /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -191,8 +191,8 @@ class DenseForwardDataFlowAnalysis
   /// Visit an operation with the dense lattice before its execution. This
   /// function is expected to set the dense lattice after its execution and
   /// trigger change propagation in case of change.
-  virtual void visitOperation(Operation *op, const LatticeT &before,
-                              LatticeT *after) = 0;
+  virtual LogicalResult visitOperation(Operation *op, const LatticeT &before,
+                                       LatticeT *after) = 0;
 
   /// Hook for customizing the behavior of lattice propagation along the call
   /// control flow edges. Two types of (forward) propagation are possible here:
@@ -263,10 +263,11 @@ class DenseForwardDataFlowAnalysis
 
   /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.
-  void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
-                          AbstractDenseLattice *after) final {
-    visitOperation(op, static_cast<const LatticeT &>(before),
-                   static_cast<LatticeT *>(after));
+  LogicalResult visitOperationImpl(Operation *op,
+                                   const AbstractDenseLattice &before,
+                                   AbstractDenseLattice *after) final {
+    return visitOperation(op, static_cast<const LatticeT &>(before),
+                          static_cast<LatticeT *>(after));
   }
   void visitCallControlFlowTransfer(CallOpInterface call,
                                     CallControlFlowAction action,
@@ -326,9 +327,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
 protected:
   /// Propagate the dense lattice after the execution of an operation to the
   /// lattice before its execution.
-  virtual void visitOperationImpl(Operation *op,
-                                  const AbstractDenseLattice &after,
-                                  AbstractDenseLattice *before) = 0;
+  virtual LogicalResult visitOperationImpl(Operation *op,
+                                           const AbstractDenseLattice &after,
+                                           AbstractDenseLattice *before) = 0;
 
   /// Get the dense lattice before the execution of the program point. That is,
   /// before the execution of the given operation or after the execution of the
@@ -353,7 +354,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Visit an operation. Dispatches to specialized methods for call or region
   /// control-flow operations. Otherwise, this function invokes the operation
   /// transfer function.
-  virtual void processOperation(Operation *op);
+  virtual LogicalResult processOperation(Operation *op);
 
   /// Propagate the dense lattice backwards along the control flow edge from
   /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -442,8 +443,8 @@ class DenseBackwardDataFlowAnalysis
   /// Transfer function. Visits an operation with the dense lattice after its
   /// execution. This function is expected to set the dense lattice before its
   /// execution and trigger propagation in case of change.
-  virtual void visitOperation(Operation *op, const LatticeT &after,
-                              LatticeT *before) = 0;
+  virtual LogicalResult visitOperation(Operation *op, const LatticeT &after,
+                                       LatticeT *before) = 0;
 
   /// Hook for customizing the behavior of lattice propagation along the call
   /// control flow edges. Two types of (back) propagation are possible here:
@@ -513,10 +514,11 @@ class DenseBackwardDataFlowAnalysis
 
   /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.
-  void visitOperationImpl(Operation *op, const AbstractDenseLattice &after,
-                          AbstractDenseLattice *before) final {
-    visitOperation(op, static_cast<const LatticeT &>(after),
-                   static_cast<LatticeT *>(before));
+  LogicalResult visitOperationImpl(Operation *op,
+                                   const AbstractDenseLattice &after,
+                                   AbstractDenseLattice *before) final {
+    return visitOperation(op, static_cast<const LatticeT &>(after),
+                          static_cast<LatticeT *>(before));
   }
   void visitCallControlFlowTransfer(CallOpInterface call,
                                     CallControlFlowAction action,
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 191c023fb642cb..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -55,9 +55,10 @@ class IntegerRangeAnalysis
 
   /// Visit an operation. Invoke the transfer function on each operation that
   /// implements `InferIntRangeInterface`.
-  void visitOperation(Operation *op,
-                      ArrayRef<const IntegerValueRangeLattice *> operands,
-                      ArrayRef<IntegerValueRangeLattice *> results) override;
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const IntegerValueRangeLattice *> operands,
+                 ArrayRef<IntegerValueRangeLattice *> results) override;
 
   /// Visit block arguments or operation results of an operation with region
   /// control-flow for which values are not defined by region control-flow. This
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index caa03e26a3a423..cf1fd6e2d48caa 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -79,8 +79,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
 public:
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
-  void visitOperation(Operation *op, ArrayRef<Liveness *> operands,
-                      ArrayRef<const Liveness *> results) override;
+  LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
+                               ArrayRef<const Liveness *> results) override;
 
   void visitBranchOperand(OpOperand &operand) override;
 
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 7aadd5409cc695..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -197,7 +197,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
 
   /// The operation transfer function. Given the operand lattices, this
   /// function is expected to set the result lattices.
-  virtual void
+  virtual LogicalResult
   visitOperationImpl(Operation *op,
                      ArrayRef<const AbstractSparseLattice *> operandLattices,
                      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
@@ -238,7 +238,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Visit an operation. If this is a call operation or an operation with
   /// region control-flow, then its result lattices are set accordingly.
   /// Otherwise, the operation transfer function is invoked.
-  void visitOperation(Operation *op);
+  LogicalResult visitOperation(Operation *op);
 
   /// Visit a block to compute the lattice values of its arguments. If this is
   /// an entry block, then the argument values are determined from the block's
@@ -277,8 +277,9 @@ class SparseForwardDataFlowAnalysis
 
   /// Visit an operation with the lattices of its operands. This function is
   /// expected to set the lattices of the operation's results.
-  virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
-                              ArrayRef<StateT *> results) = 0;
+  virtual LogicalResult visitOperation(Operation *op,
+                                       ArrayRef<const StateT *> operands,
+                                       ArrayRef<StateT *> results) = 0;
 
   /// Visit a call operation to an externally defined function given the
   /// lattices of its arguments.
@@ -328,10 +329,10 @@ class SparseForwardDataFlowAnalysis
 private:
   /// Type-erased wrappers that convert the abstract lattice operands to derived
   /// lattices and invoke the virtual hooks operating on the derived lattices.
-  void visitOperationImpl(
+  LogicalResult visitOperationImpl(
       Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
       ArrayRef<AbstractSparseLattice *> resultLattices) override {
-    visitOperation(
+    return visitOperation(
         op,
         {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
          operandLattices.size()},
@@ -387,7 +388,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
 
   /// The operation transfer function. Given the result lattices, this
   /// function is expected to set the operand lattices.
-  virtual void visitOperationImpl(
+  virtual LogicalResult visitOperationImpl(
       Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
       ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
 
@@ -424,7 +425,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Visit an operation. If this is a call operation or an operation with
   /// region control-flow, then its operand lattices are set accordingly.
   /// Otherwise, the operation transfer function is invoked.
-  void visitOperation(Operation *op);
+  LogicalResult visitOperation(Operation *op);
 
   /// Visit a block.
   void visitBlock(Block *block);
@@ -474,8 +475,9 @@ class SparseBackwardDataFlowAnalysis
 
   /// Visit an operation with the lattices of its results. This function is
   /// expected to set the lattices of the operation's operands.
-  virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
-                              ArrayRef<const StateT *> results) = 0;
+  virtual LogicalResult visitOperation(Operation *op,
+                                       ArrayRef<StateT *> operands,
+                                       ArrayRef<const StateT *> results) = 0;
 
   /// Visit a call to an external function. This function is expected to set
   /// lattice values of the call operands. By default, calls `visitCallOperand`
@@ -510,10 +512,10 @@ class SparseBackwardDataFlowAnalysis
 private:
   /// Type-erased wrappers that convert the abstract lattice operands to derived
   /// lattices and invoke the virtual hooks operating on the derived lattices.
-  void visitOperationImpl(
+  LogicalResult visitOperationImpl(
       Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
       ArrayRef<const AbstractSparseLattice *> resultLattices) override {
-    visitOperation(
+    return visitOperation(
         op,
         {reinterpret_cast<StateT *const *>(operandLattices.begin()),
          operandLattices.size()},
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 16799d3c82092e..56529acd71bbf8 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -43,7 +43,7 @@ void ConstantValue::print(raw_ostream &os) const {
 // SparseConstantPropagation
 //===----------------------------------------------------------------------===//
 
-void SparseConstantPropagation::visitOperation(
+LogicalResult SparseConstantPropagation::visitOperation(
     Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
     ArrayRef<Lattice<ConstantValue> *> results) {
   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
@@ -54,14 +54,14 @@ void SparseConstantPropagation::visitOperation(
   // folding.
   if (op->getNumRegions()) {
     setAllToEntryStates(results);
-    return;
+    return success();
   }
 
   SmallVector<Attribute, 8> constantOperands;
   constantOperands.reserve(op->getNumOperands());
   for (auto *operandLattice : operands) {
     if (operandLattice->getValue().isUninitialized())
-      return;
+      return success();
     constantOperands.push_back(operandLattice->getValue().getConstantValue());
   }
 
@@ -77,7 +77,7 @@ void SparseConstantPropagation::visitOperation(
   foldResults.reserve(op->getNumResults());
   if (failed(op->fold(constantOperands, foldResults))) {
     setAllToEntryStates(results);
-    return;
+    return success();
   }
 
   // If the folding was in-place, mark the results as overdefined and reset
@@ -87,7 +87,7 @@ void SparseConstantPropagation::visitOperation(
     op->setOperands(originalOperands);
     op->setAttrs(originalAttrs);
     setAllToEntryStates(results);
-    return;
+    return success();
   }
 
   // Merge the fold results into the lattice for this operation.
@@ -108,6 +108,7 @@ void SparseConstantPropagation::visitOperation(
           lattice, *getLatticeElement(foldResult.get<Value>()));
     }
   }
+  return success();
 }
 
 void SparseConstantPropagation::setToEntryState(
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 9894810f0e04b3..33c877f78f4bf6 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -30,7 +30,9 @@ using namespace mlir::dataflow;
 
 LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
   // Visit every operation and block.
-  processOperation(top);
+  if (failed(processOperation(top)))
+    return failure();
+
   for (Region &region : top->getRegions()) {
     for (Block &block : region) {
       visitBlock(&block);
@@ -44,7 +46,7 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
 
 LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
-    processOperation(op);
+    return processOperation(op);
   else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
     visitBlock(block);
   else
@@ -94,10 +96,11 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
   }
 }
 
-void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
+LogicalResult
+AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
   // If the containing block is not executable, bail out.
   if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
-    return;
+    return success();
 
   // Get the dense lattice to update.
   AbstractDenseLattice *after = getLattice(op);
@@ -111,16 +114,20 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
 
   // If this op implements region control-flow, then control-flow dictates its
   // transfer function.
-  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
-    return visitRegionBranchOperation(op, branch, after);
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    visitRegionBranchOperation(op, branch, after);
+    return success();
+  }
 
   // If this is a call operation, then join its lattices across known return
   // sites.
-  if (auto call = dyn_cast<CallOpInterface>(op))
-    return visitCallOperation(call, *before, after);
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    visitCallOperation(call, *before, after);
+    return success();
+  }
 
   // Invoke the operation transfer function.
-  visitOperationImpl(op, *before, after);
+  return visitOperationImpl(op, *before, after);
 }
 
 void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
@@ -254,7 +261,9 @@ AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
 LogicalResult
 AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
   // Visit every operation and block.
-  processOperation(top);
+  if (failed(processOperation(top)))
+    return failure();
+
   for (Region &region : top->getRegions()) {
     for (Block &block : region) {
       visitBlock(&block);
@@ -269,7 +278,7 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
 
 LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
-    processOperation(op);
+    return processOperation(op);
   else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
     visitBlock(block);
   else
@@ -323,10 +332,11 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
                                latticeAtCalleeEntry, latticeBeforeCall);
 }
 
-void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
+LogicalResult
+AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
   // If the containing block is not executable, bail out.
   if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
-    return;
+    return success();
 
   // Get the dense lattice to update.
   AbstractDenseLattice *before = getLattice(op);
@@ -339,14 +349,17 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
     after = getLatticeFor(op, op->getBlock());
 
   // Special cases where control flow may dictate data flow.
-  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
-    return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
-                                      before);
-  if (auto call = dyn_cast<CallOpInterface>(op))
-    return visitCallOperation(call, *after, before);
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), before);
+    r...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/105448


More information about the Mlir-commits mailing list