[Mlir-commits] [mlir] [mlir][dataflow] Propagate errors from `visitOperation` (PR #105448)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 20 16:26:42 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
The base `DataFlowAnalysis::visit` returns `LogicalResult`, but the `visitOperation` hooks of the Sparse/Dense Forward/Backward wrappers don't.
Sometimes the solver needs to be aborted early when an unrecoverable condition is detected inside an analysis.
Update `visitOperation` to return `LogicalResult` and propagate the result to `solver.initializeAndRun()`. Only `visitOperation` is updated for now. Other hooks such as `visitNonControlFlowArguments` could be updated as well, but that isn't needed immediately, and leaving them out keeps this PR small.
Repurposed the `UnderlyingValueAnalysis` test analysis to exercise the new error propagation.
---
Patch is 43.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105448.diff
15 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h (+4-3)
- (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+22-20)
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+4-3)
- (modified) mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h (+2-2)
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+14-12)
- (modified) mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp (+6-5)
- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+32-19)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+6-3)
- (modified) mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp (+6-5)
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+41-27)
- (added) mlir/test/Analysis/DataFlow/test-last-modified-error.mlir (+8)
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (+22-14)
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+9-3)
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp (+21-14)
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+7-6)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 1bf991dc193874..d2d4ff9960ea36 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -101,9 +101,10 @@ class SparseConstantPropagation
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
- void visitOperation(Operation *op,
- ArrayRef<const Lattice<ConstantValue> *> operands,
- ArrayRef<Lattice<ConstantValue> *> results) override;
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const Lattice<ConstantValue> *> operands,
+ ArrayRef<Lattice<ConstantValue> *> results) override;
void setToEntryState(Lattice<ConstantValue> *lattice) override;
};
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 088b6cd7d698fc..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -87,9 +87,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
protected:
/// Propagate the dense lattice before the execution of an operation to the
/// lattice after its execution.
- virtual void visitOperationImpl(Operation *op,
- const AbstractDenseLattice &before,
- AbstractDenseLattice *after) = 0;
+ virtual LogicalResult visitOperationImpl(Operation *op,
+ const AbstractDenseLattice &before,
+ AbstractDenseLattice *after) = 0;
/// Get the dense lattice after the execution of the given program point.
virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
@@ -114,7 +114,7 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// operation, then the state after the execution of the operation is set by
/// control-flow or the callgraph. Otherwise, this function invokes the
/// operation transfer function.
- virtual void processOperation(Operation *op);
+ virtual LogicalResult processOperation(Operation *op);
/// Propagate the dense lattice forward along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -191,8 +191,8 @@ class DenseForwardDataFlowAnalysis
/// Visit an operation with the dense lattice before its execution. This
/// function is expected to set the dense lattice after its execution and
/// trigger change propagation in case of change.
- virtual void visitOperation(Operation *op, const LatticeT &before,
- LatticeT *after) = 0;
+ virtual LogicalResult visitOperation(Operation *op, const LatticeT &before,
+ LatticeT *after) = 0;
/// Hook for customizing the behavior of lattice propagation along the call
/// control flow edges. Two types of (forward) propagation are possible here:
@@ -263,10 +263,11 @@ class DenseForwardDataFlowAnalysis
/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
- void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
- AbstractDenseLattice *after) final {
- visitOperation(op, static_cast<const LatticeT &>(before),
- static_cast<LatticeT *>(after));
+ LogicalResult visitOperationImpl(Operation *op,
+ const AbstractDenseLattice &before,
+ AbstractDenseLattice *after) final {
+ return visitOperation(op, static_cast<const LatticeT &>(before),
+ static_cast<LatticeT *>(after));
}
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
@@ -326,9 +327,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
protected:
/// Propagate the dense lattice after the execution of an operation to the
/// lattice before its execution.
- virtual void visitOperationImpl(Operation *op,
- const AbstractDenseLattice &after,
- AbstractDenseLattice *before) = 0;
+ virtual LogicalResult visitOperationImpl(Operation *op,
+ const AbstractDenseLattice &after,
+ AbstractDenseLattice *before) = 0;
/// Get the dense lattice before the execution of the program point. That is,
/// before the execution of the given operation or after the execution of the
@@ -353,7 +354,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. Dispatches to specialized methods for call or region
/// control-flow operations. Otherwise, this function invokes the operation
/// transfer function.
- virtual void processOperation(Operation *op);
+ virtual LogicalResult processOperation(Operation *op);
/// Propagate the dense lattice backwards along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -442,8 +443,8 @@ class DenseBackwardDataFlowAnalysis
/// Transfer function. Visits an operation with the dense lattice after its
/// execution. This function is expected to set the dense lattice before its
/// execution and trigger propagation in case of change.
- virtual void visitOperation(Operation *op, const LatticeT &after,
- LatticeT *before) = 0;
+ virtual LogicalResult visitOperation(Operation *op, const LatticeT &after,
+ LatticeT *before) = 0;
/// Hook for customizing the behavior of lattice propagation along the call
/// control flow edges. Two types of (back) propagation are possible here:
@@ -513,10 +514,11 @@ class DenseBackwardDataFlowAnalysis
/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
- void visitOperationImpl(Operation *op, const AbstractDenseLattice &after,
- AbstractDenseLattice *before) final {
- visitOperation(op, static_cast<const LatticeT &>(after),
- static_cast<LatticeT *>(before));
+ LogicalResult visitOperationImpl(Operation *op,
+ const AbstractDenseLattice &after,
+ AbstractDenseLattice *before) final {
+ return visitOperation(op, static_cast<const LatticeT &>(after),
+ static_cast<LatticeT *>(before));
}
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 191c023fb642cb..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -55,9 +55,10 @@ class IntegerRangeAnalysis
/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferIntRangeInterface`.
- void visitOperation(Operation *op,
- ArrayRef<const IntegerValueRangeLattice *> operands,
- ArrayRef<IntegerValueRangeLattice *> results) override;
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const IntegerValueRangeLattice *> operands,
+ ArrayRef<IntegerValueRangeLattice *> results) override;
/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index caa03e26a3a423..cf1fd6e2d48caa 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -79,8 +79,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
public:
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
- void visitOperation(Operation *op, ArrayRef<Liveness *> operands,
- ArrayRef<const Liveness *> results) override;
+ LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
+ ArrayRef<const Liveness *> results) override;
void visitBranchOperand(OpOperand &operand) override;
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 7aadd5409cc695..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -197,7 +197,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// The operation transfer function. Given the operand lattices, this
/// function is expected to set the result lattices.
- virtual void
+ virtual LogicalResult
visitOperationImpl(Operation *op,
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
@@ -238,7 +238,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. If this is a call operation or an operation with
/// region control-flow, then its result lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
- void visitOperation(Operation *op);
+ LogicalResult visitOperation(Operation *op);
/// Visit a block to compute the lattice values of its arguments. If this is
/// an entry block, then the argument values are determined from the block's
@@ -277,8 +277,9 @@ class SparseForwardDataFlowAnalysis
/// Visit an operation with the lattices of its operands. This function is
/// expected to set the lattices of the operation's results.
- virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
- ArrayRef<StateT *> results) = 0;
+ virtual LogicalResult visitOperation(Operation *op,
+ ArrayRef<const StateT *> operands,
+ ArrayRef<StateT *> results) = 0;
/// Visit a call operation to an externally defined function given the
/// lattices of its arguments.
@@ -328,10 +329,10 @@ class SparseForwardDataFlowAnalysis
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
- void visitOperationImpl(
+ LogicalResult visitOperationImpl(
Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) override {
- visitOperation(
+ return visitOperation(
op,
{reinterpret_cast<const StateT *const *>(operandLattices.begin()),
operandLattices.size()},
@@ -387,7 +388,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// The operation transfer function. Given the result lattices, this
/// function is expected to set the operand lattices.
- virtual void visitOperationImpl(
+ virtual LogicalResult visitOperationImpl(
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
@@ -424,7 +425,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation. If this is a call operation or an operation with
/// region control-flow, then its operand lattices are set accordingly.
/// Otherwise, the operation transfer function is invoked.
- void visitOperation(Operation *op);
+ LogicalResult visitOperation(Operation *op);
/// Visit a block.
void visitBlock(Block *block);
@@ -474,8 +475,9 @@ class SparseBackwardDataFlowAnalysis
/// Visit an operation with the lattices of its results. This function is
/// expected to set the lattices of the operation's operands.
- virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
- ArrayRef<const StateT *> results) = 0;
+ virtual LogicalResult visitOperation(Operation *op,
+ ArrayRef<StateT *> operands,
+ ArrayRef<const StateT *> results) = 0;
/// Visit a call to an external function. This function is expected to set
/// lattice values of the call operands. By default, calls `visitCallOperand`
@@ -510,10 +512,10 @@ class SparseBackwardDataFlowAnalysis
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
- void visitOperationImpl(
+ LogicalResult visitOperationImpl(
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) override {
- visitOperation(
+ return visitOperation(
op,
{reinterpret_cast<StateT *const *>(operandLattices.begin()),
operandLattices.size()},
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 16799d3c82092e..56529acd71bbf8 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -43,7 +43,7 @@ void ConstantValue::print(raw_ostream &os) const {
// SparseConstantPropagation
//===----------------------------------------------------------------------===//
-void SparseConstantPropagation::visitOperation(
+LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
@@ -54,14 +54,14 @@ void SparseConstantPropagation::visitOperation(
// folding.
if (op->getNumRegions()) {
setAllToEntryStates(results);
- return;
+ return success();
}
SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
for (auto *operandLattice : operands) {
if (operandLattice->getValue().isUninitialized())
- return;
+ return success();
constantOperands.push_back(operandLattice->getValue().getConstantValue());
}
@@ -77,7 +77,7 @@ void SparseConstantPropagation::visitOperation(
foldResults.reserve(op->getNumResults());
if (failed(op->fold(constantOperands, foldResults))) {
setAllToEntryStates(results);
- return;
+ return success();
}
// If the folding was in-place, mark the results as overdefined and reset
@@ -87,7 +87,7 @@ void SparseConstantPropagation::visitOperation(
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
setAllToEntryStates(results);
- return;
+ return success();
}
// Merge the fold results into the lattice for this operation.
@@ -108,6 +108,7 @@ void SparseConstantPropagation::visitOperation(
lattice, *getLatticeElement(foldResult.get<Value>()));
}
}
+ return success();
}
void SparseConstantPropagation::setToEntryState(
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 9894810f0e04b3..33c877f78f4bf6 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -30,7 +30,9 @@ using namespace mlir::dataflow;
LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
- processOperation(top);
+ if (failed(processOperation(top)))
+ return failure();
+
for (Region &region : top->getRegions()) {
for (Block &block : region) {
visitBlock(&block);
@@ -44,7 +46,7 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- processOperation(op);
+ return processOperation(op);
else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
@@ -94,10 +96,11 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
}
}
-void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
+LogicalResult
+AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If the containing block is not executable, bail out.
if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
- return;
+ return success();
// Get the dense lattice to update.
AbstractDenseLattice *after = getLattice(op);
@@ -111,16 +114,20 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If this op implements region control-flow, then control-flow dictates its
// transfer function.
- if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
- return visitRegionBranchOperation(op, branch, after);
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ visitRegionBranchOperation(op, branch, after);
+ return success();
+ }
// If this is a call operation, then join its lattices across known return
// sites.
- if (auto call = dyn_cast<CallOpInterface>(op))
- return visitCallOperation(call, *before, after);
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ visitCallOperation(call, *before, after);
+ return success();
+ }
// Invoke the operation transfer function.
- visitOperationImpl(op, *before, after);
+ return visitOperationImpl(op, *before, after);
}
void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
@@ -254,7 +261,9 @@ AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
- processOperation(top);
+ if (failed(processOperation(top)))
+ return failure();
+
for (Region &region : top->getRegions()) {
for (Block &block : region) {
visitBlock(&block);
@@ -269,7 +278,7 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- processOperation(op);
+ return processOperation(op);
else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
@@ -323,10 +332,11 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
latticeAtCalleeEntry, latticeBeforeCall);
}
-void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
+LogicalResult
+AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// If the containing block is not executable, bail out.
if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
- return;
+ return success();
// Get the dense lattice to update.
AbstractDenseLattice *before = getLattice(op);
@@ -339,14 +349,17 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
after = getLatticeFor(op, op->getBlock());
// Special cases where control flow may dictate data flow.
- if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
- return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
- before);
- if (auto call = dyn_cast<CallOpInterface>(op))
- return visitCallOperation(call, *after, before);
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), before);
+ r...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/105448
More information about the Mlir-commits
mailing list