[Mlir-commits] [mlir] 4e98d61 - [mlir] Implement backward dataflow.
Alex Zinenko
llvmlistbot at llvm.org
Tue Dec 13 09:35:36 PST 2022
Author: Matthias Kramm
Date: 2022-12-13T18:35:27+01:00
New Revision: 4e98d611ef67e46e60743b8429bad5eb531a1e7c
URL: https://github.com/llvm/llvm-project/commit/4e98d611ef67e46e60743b8429bad5eb531a1e7c
DIFF: https://github.com/llvm/llvm-project/commit/4e98d611ef67e46e60743b8429bad5eb531a1e7c.diff
LOG: [mlir] Implement backward dataflow.
This enables interprocedural liveness analysis, very busy expressions
analysis, etc.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D138935
Added:
mlir/test/Analysis/DataFlow/test-written-to.mlir
mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
Modified:
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/test/lib/Analysis/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 62d32b21ac656..a178c6024fdee 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,6 +16,7 @@
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -39,7 +40,15 @@ class AbstractSparseLattice : public AnalysisState {
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
- virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
+ virtual ChangeResult join(const AbstractSparseLattice &rhs) {
+ return ChangeResult::NoChange;
+ }
+
+ /// Meet (intersect) the information in this lattice with 'rhs'. Returns
+ /// if the value of the lattice changed.
+ virtual ChangeResult meet(const AbstractSparseLattice &rhs) {
+ return ChangeResult::NoChange;
+ }
/// When the lattice gets updated, propagate an update to users of the value
/// using its use-def chain to subscribed analyses.
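Since both hooks default to `ChangeResult::NoChange`, the same lattice class can serve either direction: forward analyses override `join`, backward analyses override `meet`. As a minimal sketch of the backward side (hypothetical, not part of this patch — a liveness-flavored lattice in the spirit of the analyses this enables):

// Hypothetical backward "liveness" lattice: `meet` records that some
// transitive user still needs the value.
struct LivenessLattice : public mlir::dataflow::AbstractSparseLattice {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LivenessLattice)
  using AbstractSparseLattice::AbstractSparseLattice;

  mlir::ChangeResult meet(const AbstractSparseLattice &rhs) override {
    bool rhsLive = static_cast<const LivenessLattice &>(rhs).isLive;
    if (isLive || !rhsLive)
      return mlir::ChangeResult::NoChange;
    isLive = true; // Only ever moves dead -> live, so `meet` is monotonic.
    return mlir::ChangeResult::Change;
  }

  void print(llvm::raw_ostream &os) const override {
    os << (isLive ? "live" : "dead");
  }

  bool isLive = false;
};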
@@ -86,14 +95,18 @@ class Lattice : public AbstractSparseLattice {
return const_cast<Lattice<ValueT> *>(this)->getValue();
}
+ using LatticeT = Lattice<ValueT>;
+
/// Join the information contained in the 'rhs' lattice into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const AbstractSparseLattice &rhs) override {
- const Lattice<ValueT> &rhsLattice =
- static_cast<const Lattice<ValueT> &>(rhs);
+ return join(static_cast<const LatticeT &>(rhs).getValue());
+ }
- // Join the rhs value into this lattice.
- return join(rhsLattice.getValue());
+ /// Meet (intersect) the information contained in the 'rhs' lattice with
+ /// this lattice. Returns if the state of the current lattice changed.
+ ChangeResult meet(const AbstractSparseLattice &rhs) override {
+ return meet(static_cast<const LatticeT &>(rhs).getValue());
}
/// Join the information contained in the 'rhs' value into this
@@ -114,6 +127,38 @@ class Lattice : public AbstractSparseLattice {
return ChangeResult::Change;
}
+ /// Trait to check if `T` provides a `meet` method. Needed since for forward
+ /// analysis, lattices will only have a `join`, no `meet`, but we want to use
+ /// the same `Lattice` class for both directions.
+ template <typename T>
+ using has_meet = decltype(T::meet(std::declval<T>(), std::declval<T>()));
+ template <typename T>
+ using lattice_has_meet = llvm::is_detected<has_meet, T>;
+
+ /// Meet (intersect) the information contained in the 'rhs' value with this
+ /// lattice. Returns if the state of the current lattice changed. If the
+ /// lattice elements don't have a `meet` method, this is a no-op (see below).
+ template <typename VT,
+ std::enable_if_t<lattice_has_meet<VT>::value, int> = 0>
+ ChangeResult meet(const VT &rhs) {
+ ValueT newValue = ValueT::meet(value, rhs);
+ assert(ValueT::meet(newValue, value) == newValue &&
+ "expected `meet` to be monotonic");
+ assert(ValueT::meet(newValue, rhs) == newValue &&
+ "expected `meet` to be monotonic");
+
+ // Update the current optimistic value if something changed.
+ if (newValue == value)
+ return ChangeResult::NoChange;
+
+ value = newValue;
+ return ChangeResult::Change;
+ }
+
+ template <typename VT,
+ std::enable_if_t<!lattice_has_meet<VT>::value, int> = 0>
+ ChangeResult meet(const VT &rhs) {
+ return ChangeResult::NoChange;
+ }
+
/// Print the lattice element.
void print(raw_ostream &os) const override { value.print(os); }
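For the `Lattice<ValueT>` convenience wrapper, the value type itself supplies static `join`/`meet` functions. A hypothetical value type supporting both directions might look as follows (illustrative only; `std::set` is used for brevity). Union and intersection are idempotent, so the monotonicity asserts above hold:

// Illustrative value type for Lattice<WriteSet>: `join` is set union
// (may-information, forward), `meet` is set intersection
// (must-information, backward).
struct WriteSet {
  std::set<unsigned> ids;

  static WriteSet join(const WriteSet &lhs, const WriteSet &rhs) {
    WriteSet result = lhs;
    result.ids.insert(rhs.ids.begin(), rhs.ids.end());
    return result;
  }

  static WriteSet meet(const WriteSet &lhs, const WriteSet &rhs) {
    WriteSet result;
    for (unsigned id : lhs.ids)
      if (rhs.ids.count(id))
        result.ids.insert(id);
    return result;
  }

  bool operator==(const WriteSet &rhs) const { return ids == rhs.ids; }
  void print(llvm::raw_ostream &os) const { os << ids.size() << " ids"; }
};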
@@ -289,6 +334,135 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
}
};
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse (backward) data-flow analyses. Similar to
+/// AbstractSparseDataFlowAnalysis, but walks bottom to top.
+class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
+public:
+ /// Initialize the analysis by visiting the operation and everything nested
+ /// under it.
+ LogicalResult initialize(Operation *top) override;
+
+ /// Visit a program point. If this is a call operation or an operation with
+ /// block or region control-flow, then operand lattices are set accordingly.
+ /// Otherwise, invokes the operation transfer function (`visitOperationImpl`).
+ LogicalResult visit(ProgramPoint point) override;
+
+protected:
+ explicit AbstractSparseBackwardDataFlowAnalysis(
+ DataFlowSolver &solver, SymbolTableCollection &symbolTable);
+
+ /// The operation transfer function. Given the result lattices, this
+ /// function is expected to set the operand lattices.
+ virtual void visitOperationImpl(
+ Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
+ /// Visit operands on branch instructions that are not forwarded.
+ virtual void visitBranchOperand(OpOperand &operand) = 0;
+
+ /// Set the given lattice element(s) at control flow exit point(s).
+ virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
+
+ /// Set the given lattice element(s) at control flow exit point(s).
+ void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
+
+ /// Get the lattice element for a value.
+ virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+
+ /// Get the lattice elements for a range of values.
+ SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values);
+
+ /// Meet the lattice element and propagate an update if it changed.
+ void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+
+private:
+ /// Recursively initialize the analysis on nested operations and blocks.
+ LogicalResult initializeRecursively(Operation *op);
+
+ /// Visit an operation. If this is a call operation or an operation with
+ /// region control-flow, then its operand lattices are set accordingly.
+ /// Otherwise, the operation transfer function is invoked.
+ void visitOperation(Operation *op);
+
+ /// Visit a block.
+ void visitBlock(Block *block);
+
+ /// Visit an op with regions (e.g. `scf.while`).
+ void visitRegionSuccessors(RegionBranchOpInterface branch,
+ ArrayRef<AbstractSparseLattice *> operands);
+
+ /// Get the lattice element for a value, and also set up
+ /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+ /// if the value changes.
+ const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+ Value value);
+
+ /// Get the lattice elements for a range of values, and also set up
+ /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+ /// if any of the values change.
+ SmallVector<const AbstractSparseLattice *>
+ getLatticeElementsFor(ProgramPoint point, ValueRange values);
+
+ SymbolTableCollection &symbolTable;
+};
+
+//===----------------------------------------------------------------------===//
+// SparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (backward) data-flow analysis for propagating SSA value lattices
+/// backwards across the IR by implementing transfer functions for operations.
+///
+/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+template <typename StateT>
+class SparseBackwardDataFlowAnalysis
+ : public AbstractSparseBackwardDataFlowAnalysis {
+public:
+ explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver,
+ SymbolTableCollection &symbolTable)
+ : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+ /// Visit an operation with the lattices of its results. This function is
+ /// expected to set the lattices of the operation's operands.
+ virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
+ ArrayRef<const StateT *> results) = 0;
+
+protected:
+ /// Get the lattice element for a value.
+ StateT *getLatticeElement(Value value) override {
+ return getOrCreate<StateT>(value);
+ }
+
+ /// Set the given lattice element(s) at control flow exit point(s).
+ virtual void setToExitState(StateT *lattice) = 0;
+ void setToExitState(AbstractSparseLattice *lattice) override {
+ return setToExitState(reinterpret_cast<StateT *>(lattice));
+ }
+ void setAllToExitStates(ArrayRef<StateT *> lattices) {
+ AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+ {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
+ lattices.size()});
+ }
+
+private:
+ /// Type-erased wrappers that convert the abstract lattice operands to derived
+ /// lattices and invoke the virtual hooks operating on the derived lattices.
+ void visitOperationImpl(
+ Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+ visitOperation(
+ op,
+ {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+ operandLattices.size()},
+ {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+ resultLattices.size()});
+ }
+};
+
} // end namespace dataflow
} // end namespace mlir
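Tying the pieces together, a client subclasses `SparseBackwardDataFlowAnalysis`, implements the three hooks, and loads the analysis into a `DataFlowSolver`. A hedged sketch, reusing the hypothetical `LivenessLattice` from above (the test pass added below does the same with a richer lattice and shows the full set of includes):

class LivenessAnalysis
    : public mlir::dataflow::SparseBackwardDataFlowAnalysis<LivenessLattice> {
public:
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  void visitOperation(mlir::Operation *op,
                      llvm::ArrayRef<LivenessLattice *> operands,
                      llvm::ArrayRef<const LivenessLattice *> results) override {
    // Every operand is needed whenever any result is needed.
    for (const LivenessLattice *result : results)
      for (LivenessLattice *operand : operands)
        meet(operand, *result);
  }

  void visitBranchOperand(mlir::OpOperand &operand) override {}

  // A default-constructed lattice is already "dead", which is the right
  // state at exit points, so there is nothing to reset here.
  void setToExitState(LivenessLattice *lattice) override {}
};

void runLiveness(mlir::Operation *top, mlir::SymbolTableCollection &symbols) {
  mlir::DataFlowSolver solver;
  // The backward analysis queries `Executable` block states, so
  // DeadCodeAnalysis must be loaded alongside it (the test pass below
  // also loads SparseConstantPropagation).
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<LivenessAnalysis>(symbols);
  if (mlir::failed(solver.initializeAndRun(top)))
    return; // Initialization or fixpoint iteration failed.
}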
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 1ea60b45700cf..c5d2ac4ca01ff 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -41,7 +41,7 @@ LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
if (region.empty())
continue;
for (Value argument : region.front().getArguments())
- setAllToEntryStates(getLatticeElement(argument));
+ setToEntryState(getLatticeElement(argument));
}
return initializeRecursively(top);
@@ -281,3 +281,271 @@ void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
const AbstractSparseLattice &rhs) {
propagateIfChanged(lhs, lhs->join(rhs));
}
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
+ DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+ : DataFlowAnalysis(solver), symbolTable(symbolTable) {
+ registerPointKind<CFGEdge>();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
+ return initializeRecursively(top);
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+ visitOperation(op);
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+ // Initialize ops in reverse order, so we can do as much initial
+ // propagation as possible without having to go through the
+ // solver queue.
+ for (auto it = block.rbegin(); it != block.rend(); it++)
+ if (failed(initializeRecursively(&*it)))
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
+ if (Operation *op = point.dyn_cast<Operation *>())
+ visitOperation(op);
+ else if (Block *block = point.dyn_cast<Block *>())
+ // For backward dataflow, we don't have to do any work for the blocks
+ // themselves. CFG edges between blocks are processed by the BranchOp
+ // logic in `visitOperation`, and entry blocks for functions are tied
+ // to the CallOp arguments by visitOperation.
+ return success();
+ else
+ return failure();
+ return success();
+}
+
+SmallVector<AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
+ SmallVector<AbstractSparseLattice *> resultLattices;
+ resultLattices.reserve(values.size());
+ for (Value result : values) {
+ AbstractSparseLattice *resultLattice = getLatticeElement(result);
+ resultLattices.push_back(resultLattice);
+ }
+ return resultLattices;
+}
+
+SmallVector<const AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
+ ProgramPoint point, ValueRange values) {
+ SmallVector<const AbstractSparseLattice *> resultLattices;
+ resultLattices.reserve(values.size());
+ for (Value result : values) {
+ const AbstractSparseLattice *resultLattice =
+ getLatticeElementFor(point, result);
+ resultLattices.push_back(resultLattice);
+ }
+ return resultLattices;
+}
+
+static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
+ return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
+}
+
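An aside on the helper above (illustrative, not part of the patch): backward propagation needs the operand slots — `OpOperand`s, which know their operand number — rather than the plain `Value`s an `OperandRange` iterates over, which is why the range is reinterpreted over the underlying `OpOperand` array. In miniature, with `terminator`, `regionIndex`, and `unaccounted` as in the code below:

OperandRange forwarded = terminator.getSuccessorOperands(regionIndex);
for (OpOperand &slot : operandsToOpOperands(forwarded))
  unaccounted.reset(slot.getOperandNumber());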
+void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+ // If we're in a dead block, bail out.
+ if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ return;
+
+ SmallVector<AbstractSparseLattice *> operandLattices =
+ getLatticeElements(op->getOperands());
+ SmallVector<const AbstractSparseLattice *> resultLattices =
+ getLatticeElementsFor(op, op->getResults());
+
+ // Block arguments of region branch operations flow back into the operands
+ // of the parent op.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ visitRegionSuccessors(branch, operandLattices);
+ return;
+ }
+
+ if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ // Block arguments of successor blocks flow back into our operands.
+
+ // We remember all operands not forwarded to any block in a BitVector.
+ // We can't just cut out a range here, since the non-forwarded ops might
+ // be non-contiguous (if there's more than one successor).
+ BitVector unaccounted(op->getNumOperands(), true);
+
+ for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
+ SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
+ OperandRange forwarded = successorOperands.getForwardedOperands();
+ if (forwarded.size()) {
+ MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
+ forwarded.getBeginOperandIndex(), forwarded.size());
+ for (OpOperand &operand : operands) {
+ unaccounted.reset(operand.getOperandNumber());
+ if (Optional<BlockArgument> blockArg =
+ detail::getBranchSuccessorArgument(
+ successorOperands, operand.getOperandNumber(), block)) {
+ meet(getLatticeElement(operand.get()),
+ *getLatticeElementFor(op, *blockArg));
+ }
+ }
+ }
+ }
+ // Operands not forwarded to successor blocks are typically parameters
+ // of the branch operation itself (for example the boolean for if/else).
+ for (int index : unaccounted.set_bits()) {
+ OpOperand &operand = op->getOpOperand(index);
+ visitBranchOperand(operand);
+ }
+ return;
+ }
+
+ // For function calls, connect the arguments of the entry blocks
+ // to the operands of the call op.
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ Operation *callableOp = call.resolveCallable(&symbolTable);
+ if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+ Region *region = callable.getCallableRegion();
+ if (!region->empty()) {
+ Block &block = region->front();
+ for (auto [blockArg, operand] :
+ llvm::zip(block.getArguments(), operandLattices)) {
+ meet(operand, *getLatticeElementFor(op, blockArg));
+ }
+ }
+ return;
+ }
+ }
+
+ // The block arguments of the branched to region flow back into the
+ // operands of the yield operation.
+ if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+ SmallVector<RegionSuccessor> successors;
+ SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
+ branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+ operands, successors);
+ // All operands not forwarded to any successor. This set can be
+ // non-contiguous in the presence of multiple successors.
+ BitVector unaccounted(op->getNumOperands(), true);
+
+ for (const RegionSuccessor &successor : successors) {
+ ValueRange inputs = successor.getSuccessorInputs();
+ Region *region = successor.getSuccessor();
+ OperandRange operands =
+ region ? terminator.getSuccessorOperands(region->getRegionNumber())
+ : terminator.getSuccessorOperands({});
+ MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+ for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) {
+ meet(getLatticeElement(opoperand.get()),
+ *getLatticeElementFor(op, input));
+ unaccounted.reset(
+ const_cast<OpOperand &>(opoperand).getOperandNumber());
+ }
+ }
+ // Visit operands of the branch op not forwarded to the next region.
+ // (e.g. the boolean of `scf.condition`).
+ for (int index : unaccounted.set_bits()) {
+ visitBranchOperand(op->getOpOperand(index));
+ }
+ return;
+ }
+ }
+
+ // yield-like ops usually don't implement `RegionBranchTerminatorOpInterface`,
+ // since they behave like a return in the sense that they forward to the
+ // results of some other (here: the parent) op.
+ if (op->hasTrait<OpTrait::ReturnLike>()) {
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+ OperandRange operands = op->getOperands();
+ ResultRange results = op->getParentOp()->getResults();
+ assert(results.size() == operands.size() &&
+ "Can't derive arg mapping for yield-like op.");
+ for (auto [operand, result] : llvm::zip(operands, results))
+ meet(getLatticeElement(operand), *getLatticeElementFor(op, result));
+ return;
+ }
+
+ // Going backwards, the operands of the return are derived from the
+ // results of all CallOps calling this CallableOp.
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+ const PredecessorState *callsites =
+ getOrCreateFor<PredecessorState>(op, callable);
+ if (callsites->allPredecessorsKnown()) {
+ for (Operation *call : callsites->getKnownPredecessors()) {
+ SmallVector<const AbstractSparseLattice *> callResultLattices =
+ getLatticeElementsFor(op, call->getResults());
+ for (auto [operand, result] :
+ llvm::zip(operandLattices, callResultLattices))
+ meet(operand, *result);
+ }
+ } else {
+ // If we don't know all the callers, we can't know where the
+ // returned values go. Note that, in particular, this will trigger
+ // for the return ops of any public functions.
+ setAllToExitStates(operandLattices);
+ }
+ return;
+ }
+ }
+
+ visitOperationImpl(op, operandLattices, resultLattices);
+}
+
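To summarize the control flow of `visitOperation` above (a reading aid, not new behavior):

// Dispatch order of the backward visitOperation:
//  1. RegionBranchOpInterface: entry block arguments of each region
//     flow back into the op's operands (visitRegionSuccessors).
//  2. BranchOpInterface: successor block arguments flow back into the
//     forwarded operands; non-forwarded operands (e.g. a condition)
//     go to visitBranchOperand.
//  3. CallOpInterface: callee entry block arguments flow back into
//     the call operands.
//  4. RegionBranchTerminatorOpInterface: successor region inputs flow
//     back into the terminator's operands.
//  5. ReturnLike: parent op results (inside a RegionBranchOp) or call
//     site results (inside a CallableOp) flow back into the operands;
//     with unknown callers, operands are set to the exit state.
//  6. Anything else: the user-provided transfer function,
//     visitOperationImpl.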
+void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
+ RegionBranchOpInterface branch,
+ ArrayRef<AbstractSparseLattice *> operandLattices) {
+ Operation *op = branch.getOperation();
+ SmallVector<RegionSuccessor> successors;
+ SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
+ branch.getSuccessorRegions(/*index=*/{}, operands, successors);
+
+ // All operands not forwarded to any successor. This set can be non-contiguous
+ // in the presence of multiple successors.
+ BitVector unaccounted(op->getNumOperands(), true);
+
+ for (RegionSuccessor &successor : successors) {
+ Region *region = successor.getSuccessor();
+ OperandRange operands =
+ region ? branch.getSuccessorEntryOperands(region->getRegionNumber())
+ : branch.getSuccessorEntryOperands({});
+ MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+ ValueRange inputs = successor.getSuccessorInputs();
+ for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
+ meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
+ unaccounted.reset(operand.getOperandNumber());
+ }
+ }
+ // All operands not forwarded to regions are typically parameters of the
+ // branch operation itself (for example the boolean for if/else).
+ for (int index : unaccounted.set_bits()) {
+ visitBranchOperand(op->getOpOperand(index));
+ }
+}
+
+const AbstractSparseLattice *
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+ Value value) {
+ AbstractSparseLattice *state = getLatticeElement(value);
+ addDependency(state, point);
+ return state;
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+ ArrayRef<AbstractSparseLattice *> lattices) {
+ for (AbstractSparseLattice *lattice : lattices)
+ setToExitState(lattice);
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::meet(
+ AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
+ propagateIfChanged(lhs, lhs->meet(rhs));
+}
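To see the pieces working together, here is a short trace of the first test case in the new .mlir file below (comments only; the behavior is that of the WrittenTo analysis added further down):

// Worked trace for @test_two_writes:
//  - Visiting memref.store %c0, %m0 {tag_name = "a"} adds "a" to the
//    lattice attached to the value %c0.
//  - %c0 is the result of arith.constant {tag = "constant0"}, so the
//    test pass prints that op's result lattice as [a].
//  - Symmetrically, %c1 picks up [b] from the second store.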
diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
new file mode 100644
index 0000000000000..09c9212fa216c
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: constant0
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: constant1
+// CHECK: result #0: [b]
+func.func @test_two_writes(%m0: memref<i32>, %m1: memref<i32>) -> (memref<i32>, memref<i32>) {
+ %c0 = arith.constant {tag = "constant0"} 0 : i32
+ %c1 = arith.constant {tag = "constant1"} 1 : i32
+ memref.store %c0, %m0[] {tag_name = "a"} : memref<i32>
+ memref.store %c1, %m1[] {tag_name = "b"} : memref<i32>
+ return %m0, %m1 : memref<i32>, memref<i32>
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: c2
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: c3
+// CHECK: result #0: [a]
+func.func @test_if(%m0: memref<i32>, %m1: memref<i32>, %condition: i1) {
+ %c0 = arith.constant {tag = "c0"} 2 : i32
+ %c1 = arith.constant {tag = "c1"} 3 : i32
+ %condition2 = arith.addi %condition, %condition {tag = "condition"} : i1
+ %0, %1 = scf.if %condition2 -> (i32, i32) {
+ %c2 = arith.constant {tag = "c2"} 0 : i32
+ scf.yield %c2, %c0: i32, i32
+ } else {
+ %c3 = arith.constant {tag = "c3"} 1 : i32
+ scf.yield %c3, %c1: i32, i32
+ }
+ memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+ memref.store %1, %m1[] {tag_name = "b"} : memref<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a c]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b c]
+// CHECK-LABEL: test_tag: br
+// CHECK: operand #0: [brancharg0]
+func.func @test_blocks(%m0: memref<i32>,
+ %m1: memref<i32>,
+ %m2: memref<i32>, %cond : i1) {
+ %0 = arith.constant {tag = "c0"} 0 : i32
+ %1 = arith.constant {tag = "c1"} 1 : i32
+ cf.cond_br %cond, ^a(%0: i32), ^b(%1: i32) {tag = "br"}
+^a(%a0: i32):
+ memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+ cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+ memref.store %b0, %m1[] {tag_name = "b"} : memref<i32>
+ cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+ memref.store %c0, %m2[] {tag_name = "c"} : memref<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: two
+// CHECK: result #0: [a]
+func.func @test_infinite_loop(%m0: memref<i32>) {
+ %0 = arith.constant 0 : i32
+ %1 = arith.constant 1 : i32
+ %2 = arith.constant {tag = "two"} 2 : i32
+ %3 = arith.constant -1 : i32
+ cf.br ^loop(%0, %1, %2: i32, i32, i32)
+^loop(%a: i32, %b: i32, %c: i32):
+ memref.store %a, %m0[] {tag_name = "a"} : memref<i32>
+ cf.br ^loop(%b, %c, %3 : i32, i32, i32)
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a b c]
+func.func @test_switch(%flag: i32, %m0: memref<i32>) {
+ %0 = arith.constant {tag = "c0"} 0 : i32
+ cf.switch %flag : i32, [
+ default: ^a(%0 : i32),
+ 42: ^b(%0 : i32),
+ 43: ^c(%0 : i32)
+ ]
+^a(%a0: i32):
+ memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+ cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+ memref.store %b0, %m0[] {tag_name = "b"} : memref<i32>
+ cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+ memref.store %c0, %m0[] {tag_name = "c"} : memref<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: add
+// CHECK: result #0: [a]
+func.func @test_caller(%m0: memref<f32>, %arg: f32) {
+ %0 = arith.addf %arg, %arg {tag = "add"} : f32
+ %1 = func.call @callee(%0) : (f32) -> f32
+ %2 = arith.mulf %1, %1 : f32
+ %3 = arith.mulf %2, %2 : f32
+ %4 = arith.mulf %3, %3 : f32
+ memref.store %4, %m0[] {tag_name = "a"} : memref<f32>
+ return
+}
+
+func.func private @callee(%0 : f32) -> f32 {
+ %1 = arith.mulf %0, %0 : f32
+ %2 = arith.mulf %1, %1 : f32
+ func.return %2 : f32
+}
+
+// -----
+
+func.func private @callee(%0 : f32) -> f32 {
+ %1 = arith.mulf %0, %0 : f32
+ func.return %1 : f32
+}
+
+// CHECK-LABEL: test_tag: sub
+// CHECK: result #0: [a]
+func.func @test_caller_below_callee(%m0: memref<f32>, %arg: f32) {
+ %0 = arith.subf %arg, %arg {tag = "sub"} : f32
+ %1 = func.call @callee(%0) : (f32) -> f32
+ memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+ return
+}
+
+// -----
+
+func.func private @callee1(%0 : f32) -> f32 {
+ %1 = func.call @callee2(%0) : (f32) -> f32
+ func.return %1 : f32
+}
+
+func.func private @callee2(%0 : f32) -> f32 {
+ %1 = func.call @callee3(%0) : (f32) -> f32
+ func.return %1 : f32
+}
+
+func.func private @callee3(%0 : f32) -> f32 {
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: test_tag: mul
+// CHECK: result #0: [a]
+func.func @test_callchain(%m0: memref<f32>, %arg: f32) {
+ %0 = arith.mulf %arg, %arg {tag = "mul"} : f32
+ %1 = func.call @callee1(%0) : (f32) -> f32
+ memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [c]
+// CHECK-LABEL: test_tag: init
+// CHECK: result #0: [a b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: operand #0: [brancharg0]
+func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
+ %zero = arith.constant {tag = "zero"} 0 : i32
+ %init2 = arith.addi %init, %init {tag = "init"} : i32
+ %0, %1 = scf.while (%arg1 = %zero, %arg2 = %init2) : (i32, i32) -> (i32, i32) {
+ memref.store %arg2, %m0[] {tag_name = "a"} : memref<i32>
+ scf.condition(%cond) {tag = "condition"} %arg1, %arg2 : i32, i32
+ } do {
+ ^bb0(%arg1: i32, %arg2: i32):
+ memref.store %arg1, %m0[] {tag_name = "c"} : memref<i32>
+ %res = arith.addi %arg2, %arg2 : i32
+ scf.yield %arg1, %res: i32, i32
+ }
+ memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: ten
+// CHECK: result #0: [brancharg1]
+// CHECK-LABEL: test_tag: one
+// CHECK: result #0: [brancharg2]
+// CHECK-LABEL: test_tag: x
+// CHECK: result #0: [a]
+func.func @test_for(%m0: memref<i32>) {
+ %zero = arith.constant {tag = "zero"} 0 : index
+ %ten = arith.constant {tag = "ten"} 10 : index
+ %one = arith.constant {tag = "one"} 1 : index
+ %x = arith.constant {tag = "x"} 0 : i32
+ %0 = scf.for %i = %zero to %ten step %one iter_args(%ix = %x) -> (i32) {
+ scf.yield %ix : i32
+ }
+ memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: default_a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: default_b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 1a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 1b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 2a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 2b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: switch
+// CHECK: operand #0: [brancharg0]
+func.func @test_switch(%arg0 : index, %m0: memref<i32>) {
+ %0, %1 = scf.index_switch %arg0 {tag="switch"} -> i32, i32
+ case 1 {
+ %2 = arith.constant {tag="1a"} 10 : i32
+ %3 = arith.constant {tag="1b"} 100 : i32
+ scf.yield %2, %3 : i32, i32
+ }
+ case 2 {
+ %4 = arith.constant {tag="2a"} 20 : i32
+ %5 = arith.constant {tag="2b"} 200 : i32
+ scf.yield %4, %5 : i32, i32
+ }
+ default {
+ %6 = arith.constant {tag="default_a"} 30 : i32
+ %7 = arith.constant {tag="default_b"} 300 : i32
+ scf.yield %6, %7 : i32, i32
+ }
+ memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+ memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+ return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 9d3761dec08f2..d83a8d5c070bc 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestAnalysis
DataFlow/TestDeadCodeAnalysis.cpp
DataFlow/TestDenseDataFlowAnalysis.cpp
+ DataFlow/TestBackwardDataFlowAnalysis.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..9579052c9940a
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
@@ -0,0 +1,142 @@
+//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow analysis -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public AbstractSparseLattice {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+ using AbstractSparseLattice::AbstractSparseLattice;
+
+ void print(raw_ostream &os) const override {
+ os << "[";
+ llvm::interleave(
+ writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+ os << "]";
+ }
+ ChangeResult addWrites(const SetVector<StringAttr> &writes) {
+ int sizeBefore = this->writes.size();
+ this->writes.insert(writes.begin(), writes.end());
+ int sizeAfter = this->writes.size();
+ return sizeBefore == sizeAfter ? ChangeResult::NoChange
+ : ChangeResult::Change;
+ }
+ ChangeResult meet(const AbstractSparseLattice &other) override {
+ auto rhs = reinterpret_cast<const WrittenTo *>(&other);
+ return addWrites(rhs->writes);
+ }
+
+ SetVector<StringAttr> writes;
+};
+
+/// An analysis that, by going backwards along the dataflow graph, annotates
+/// each value with all the memory resources it (or anything derived from it)
+/// is eventually written to.
+class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
+public:
+ using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+ void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
+ ArrayRef<const WrittenTo *> results) override;
+
+ void visitBranchOperand(OpOperand &operand) override;
+
+ void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+};
+
+void WrittenToAnalysis::visitOperation(Operation *op,
+ ArrayRef<WrittenTo *> operands,
+ ArrayRef<const WrittenTo *> results) {
+ if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ SetVector<StringAttr> newWrites;
+ newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
+ propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+ return;
+ }
+ // By default, every result of an op depends on every operand.
+ for (const WrittenTo *r : results) {
+ for (WrittenTo *operand : operands) {
+ meet(operand, *r);
+ }
+ addDependency(const_cast<WrittenTo *>(r), op);
+ }
+}
+
+void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
+ // Mark branch operands as "brancharg%d", with %d the operand number.
+ WrittenTo *lattice = getLatticeElement(operand.get());
+ SetVector<StringAttr> newWrites;
+ newWrites.insert(
+ StringAttr::get(operand.getOwner()->getContext(),
+ "brancharg" + Twine(operand.getOperandNumber())));
+ propagateIfChanged(lattice, lattice->addWrites(newWrites));
+}
+
+} // end anonymous namespace
+
+namespace {
+struct TestWrittenToPass
+ : public PassWrapper<TestWrittenToPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
+
+ StringRef getArgument() const override { return "test-written-to"; }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ SymbolTableCollection symbolTable;
+
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<WrittenToAnalysis>(symbolTable);
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+
+ raw_ostream &os = llvm::outs();
+ op->walk([&](Operation *op) {
+ auto tag = op->getAttrOfType<StringAttr>("tag");
+ if (!tag)
+ return;
+ os << "test_tag: " << tag.getValue() << ":\n";
+ for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
+ const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
+ assert(writtenTo && "expected a sparse lattice");
+ os << " operand #" << index << ": ";
+ writtenTo->print(os);
+ os << "\n";
+ }
+ for (auto [index, operand] : llvm::enumerate(op->getResults())) {
+ const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
+ assert(writtenTo && "expected a sparse lattice");
+ os << " result #" << index << ": ";
+ writtenTo->print(os);
+ os << "\n";
+ }
+ });
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e9200b7bf9724..622b9c945b6d5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,6 +120,7 @@ void registerTestTilingInterface();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
void registerTestTransformDialectInterpreterPass();
+void registerTestWrittenToPass();
void registerTestVectorLowerings();
void registerTestNvgpuLowerings();
} // namespace test
@@ -227,6 +228,7 @@ void registerTestPasses() {
mlir::test::registerTestTransformDialectInterpreterPass();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestNvgpuLowerings();
+ mlir::test::registerTestWrittenToPass();
}
#endif