[Mlir-commits] [mlir] d80c271 - [mlir] An implementation of dense data-flow analysis
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 7 15:12:52 PDT 2022
Author: Mogball
Date: 2022-07-07T15:12:46-07:00
New Revision: d80c271c8ac0703b5fe9ba40e55121d6dd25b389
URL: https://github.com/llvm/llvm-project/commit/d80c271c8ac0703b5fe9ba40e55121d6dd25b389
DIFF: https://github.com/llvm/llvm-project/commit/d80c271c8ac0703b5fe9ba40e55121d6dd25b389.diff
LOG: [mlir] An implementation of dense data-flow analysis
This patch introduces an implementation of dense data-flow analysis. Dense
data-flow analysis attaches a lattice before and after the execution of every
operation. The lattice state is propagated across operations by a user-defined
transfer function. The state is joined across control-flow and callgraph edges.
The patch provides an example pass that uses both a dense and a sparse analysis
together.
Depends on D127139
Reviewed By: rriddle, phisiart
Differential Revision: https://reviews.llvm.org/D127173
Added:
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
mlir/test/Analysis/DataFlow/test-last-modified.mlir
mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
Modified:
mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/lib/Analysis/CMakeLists.txt
mlir/test/lib/Analysis/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index b07994755c03..145170f74e54 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -43,6 +43,12 @@ class ConstantValue {
return constant == rhs.constant;
}
+ /// Print the constant value.
+ void print(raw_ostream &os) const;
+
+ /// The pessimistic value state of the constant value is unknown.
+ static ConstantValue getPessimisticValueState(Value value) { return {}; }
+
/// The union with another constant value is null if they are different, and
/// the same if they are the same.
static ConstantValue join(const ConstantValue &lhs,
@@ -50,9 +56,6 @@ class ConstantValue {
return lhs == rhs ? lhs : ConstantValue();
}
- /// Print the constant value.
- void print(raw_ostream &os) const;
-
private:
/// The constant value.
Attribute constant;
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
new file mode 100644
index 000000000000..43432c758824
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -0,0 +1,167 @@
+//===- DenseAnalysis.h - Dense data-flow analysis -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements dense data-flow analysis using the data-flow analysis
+// framework. The analysis is forward and conditional and uses the results of
+// dead code analysis to prune dead code during the analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
+#define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
+
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir {
+
+class RegionBranchOpInterface;
+
+namespace dataflow {
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseLattice
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dense lattice. A dense lattice is attached to
+/// operations to represent the program state after their execution or to blocks
+/// to represent the program state at the beginning of the block. A dense
+/// lattice is propagated through the IR by dense data-flow analysis.
+class AbstractDenseLattice : public AnalysisState {
+public:
+ /// A dense lattice can only be created for operations and blocks.
+ using AnalysisState::AnalysisState;
+
+ /// Join the lattice across control-flow or callgraph edges. The returned
+ /// ChangeResult is forwarded by the analysis to `propagateIfChanged`, so it
+ /// should be `Change` only when this lattice was actually modified.
+ virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
+
+ /// Reset the dense lattice to a pessimistic value. This occurs when the
+ /// analysis cannot reason about the data-flow (e.g. unknown callsites or
+ /// return sites). The returned ChangeResult is forwarded to
+ /// `propagateIfChanged`.
+ virtual ChangeResult reset() = 0;
+
+ /// Returns true if the lattice state has reached a pessimistic fixpoint. That
+ /// is, no further modifications to the lattice can occur. The analysis skips
+ /// recomputing the lattice of a program point for which this returns true.
+ virtual bool isAtFixpoint() const = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for dense data-flow analyses. Dense data-flow analysis attaches a
+/// lattice between the execution of operations and implements a transfer
+/// function from the lattice before each operation to the lattice after. The
+/// lattice contains information about the state of the program at that point.
+///
+/// In this implementation, a lattice attached to an operation represents the
+/// state of the program after its execution, and a lattice attached to a block
+/// represents the state of the program right before it starts executing its
+/// body.
+class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Initialize the analysis by visiting every program point whose execution
+  /// may modify the program state; that is, every operation and block.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point that modifies the state of the program. If this is a
+  /// block, then the state is propagated from control-flow predecessors or
+  /// callsites. If this is a call operation or region control-flow operation,
+  /// then the state after the execution of the operation is set by control-flow
+  /// or the callgraph. Otherwise, this function invokes the operation transfer
+  /// function. Returns failure if the program point is neither an operation
+  /// nor a block.
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  /// Propagate the dense lattice before the execution of an operation to the
+  /// lattice after its execution.
+  virtual void visitOperationImpl(Operation *op,
+                                  const AbstractDenseLattice &before,
+                                  AbstractDenseLattice *after) = 0;
+
+  /// Get the dense lattice after the execution of the given program point.
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+
+  /// Get the dense lattice after the execution of `point` and register
+  /// `dependent` as depending on it, so that `dependent` is re-visited when the
+  /// lattice changes.
+  const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
+                                            ProgramPoint point);
+
+  /// Reset the dense lattice to its pessimistic state and propagate an update
+  /// if it changed.
+  void reset(AbstractDenseLattice *lattice) {
+    propagateIfChanged(lattice, lattice->reset());
+  }
+
+  /// Join a lattice with another and propagate an update if it changed.
+  void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
+    propagateIfChanged(lhs, lhs->join(rhs));
+  }
+
+private:
+  /// Visit an operation. If this is a call operation or region control-flow
+  /// operation, then the state after the execution of the operation is set by
+  /// control-flow or the callgraph. Otherwise, this function invokes the
+  /// operation transfer function.
+  void visitOperation(Operation *op);
+
+  /// Visit a block. The state at the start of the block is propagated from
+  /// control-flow predecessors or callsites.
+  void visitBlock(Block *block);
+
+  /// Visit a program point within a region branch operation with predecessors
+  /// in it. This can either be an entry block of one of the regions of the
+  /// parent operation or the parent operation itself.
+  void visitRegionBranchOperation(ProgramPoint point,
+                                  RegionBranchOpInterface branch,
+                                  AbstractDenseLattice *after);
+};
+
+//===----------------------------------------------------------------------===//
+// DenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A dense (forward) data-flow analysis for propagating lattices before and
+/// after the execution of every operation across the IR by implementing
+/// transfer functions for operations.
+///
+/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`.
+template <typename LatticeT>
+class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
+  static_assert(
+      std::is_base_of<AbstractDenseLattice, LatticeT>::value,
+      "analysis state class expected to subclass AbstractDenseLattice");
+
+public:
+  using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis;
+
+  /// Visit an operation with the dense lattice before its execution. This
+  /// function is expected to set the dense lattice after its execution.
+  virtual void visitOperation(Operation *op, const LatticeT &before,
+                              LatticeT *after) = 0;
+
+protected:
+  /// Get the dense lattice after this program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract dense lattice to a derived
+  /// lattice and invoke the virtual hooks operating on the derived lattice.
+  /// The downcasts are valid because every lattice this analysis passes to
+  /// `visitOperationImpl` is obtained through `getLattice`, which returns a
+  /// `LatticeT`.
+  void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
+                          AbstractDenseLattice *after) override {
+    visitOperation(op, static_cast<const LatticeT &>(before),
+                   static_cast<LatticeT *>(after));
+  }
+};
+
+} // end namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5907da0ef8da..6456f7d6cec2 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,10 +16,12 @@
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
+
+class RegionBranchOpInterface;
+
namespace dataflow {
//===----------------------------------------------------------------------===//
@@ -80,11 +82,10 @@ class AbstractSparseLattice : public AnalysisState {
template <typename ValueT>
class Lattice : public AbstractSparseLattice {
public:
- using AbstractSparseLattice::AbstractSparseLattice;
-
- /// Get a lattice element with a known value.
- Lattice(const ValueT &knownValue = ValueT())
- : AbstractSparseLattice(Value()), knownValue(knownValue) {}
+ /// Construct a lattice with a known value.
+ explicit Lattice(Value value)
+ : AbstractSparseLattice(value),
+ knownValue(ValueT::getPessimisticValueState(value)) {}
/// Return the value held by this lattice. This requires that the value is
/// initialized.
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 2e37135adac4..efac97d665e7 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
+ DataFlow/DenseAnalysis.cpp
DataFlow/SparseAnalysis.cpp
)
@@ -30,6 +31,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
+ DataFlow/DenseAnalysis.cpp
DataFlow/SparseAnalysis.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
new file mode 100644
index 000000000000..4c0a16684ced
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -0,0 +1,172 @@
+//===- DenseAnalysis.cpp - Dense data-flow analysis -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
+  // Visit every operation and block: first `top` itself, then each block of
+  // its regions, recursing into every nested operation so that every program
+  // point receives an initial visit.
+  visitOperation(top);
+  for (Region &region : top->getRegions()) {
+    for (Block &block : region) {
+      visitBlock(&block);
+      for (Operation &op : block)
+        if (failed(initialize(&op)))
+          return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
+  // Dense lattices live only on operations and blocks; dispatch on which one
+  // this point is and report failure for any other kind of program point.
+  if (auto *op = point.dyn_cast<Operation *>()) {
+    visitOperation(op);
+    return success();
+  }
+  if (auto *block = point.dyn_cast<Block *>()) {
+    visitBlock(block);
+    return success();
+  }
+  return failure();
+}
+
+// Recompute the dense lattice after `op`. Region-branch ops and call ops take
+// their post-state from region control-flow or from the lattices after their
+// known return sites; every other op applies the user transfer function to the
+// lattice immediately before it (the previous op's lattice, or the block's
+// lattice for the first op). Lattices read through getLatticeFor also make
+// `op` a dependent of them.
+void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
+ // If the containing block is not executable, bail out.
+ if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
+ return;
+
+ // Get the dense lattice to update.
+ AbstractDenseLattice *after = getLattice(op);
+ // A lattice at its pessimistic fixpoint cannot change any further.
+ if (after->isAtFixpoint())
+ return;
+
+ // If this op implements region control-flow, then control-flow dictates its
+ // transfer function.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+ return visitRegionBranchOperation(op, branch, after);
+
+ // If this is a call operation, then join its lattices across known return
+ // sites.
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+ // If not all return sites are known, then conservatively assume we can't
+ // reason about the data-flow.
+ if (!predecessors->allPredecessorsKnown())
+ return reset(after);
+ // Each predecessor is a return-site op; its lattice is the state after it.
+ for (Operation *predecessor : predecessors->getKnownPredecessors())
+ join(after, *getLatticeFor(op, predecessor));
+ return;
+ }
+
+ // Get the dense state before the execution of the op: the lattice after the
+ // previous op, or the block's lattice (state at block entry) for the first op.
+ const AbstractDenseLattice *before;
+ if (Operation *prev = op->getPrevNode())
+ before = getLatticeFor(op, prev);
+ else
+ before = getLatticeFor(op, op->getBlock());
+ // If the incoming lattice is uninitialized, bail out. The dependency added
+ // by getLatticeFor above causes a revisit once it changes.
+ if (before->isUninitialized())
+ return;
+
+ // Invoke the operation transfer function.
+ visitOperationImpl(op, *before, after);
+}
+
+// Recompute the dense lattice at the start of `block`. The entry block of a
+// callable region joins the states right before each known callsite. Other
+// entry blocks defer to their parent's region control-flow, or are reset if the
+// parent is not a RegionBranchOpInterface. Non-entry blocks join the states
+// after the terminators of their predecessors along live CFG edges.
+void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
+ // If the block is not executable, bail out.
+ if (!getOrCreateFor<Executable>(block, block)->isLive())
+ return;
+
+ // Get the dense lattice to update.
+ AbstractDenseLattice *after = getLattice(block);
+ // A lattice at its pessimistic fixpoint cannot change any further.
+ if (after->isAtFixpoint())
+ return;
+
+ // The dense lattices of entry blocks are set by region control-flow or the
+ // callgraph.
+ if (block->isEntryBlock()) {
+ // Check if this block is the entry block of a callable region.
+ auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
+ if (callable && callable.getCallableRegion() == block->getParent()) {
+ const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ // If not all callsites are known, conservatively reset this block's lattice
+ // to its pessimistic state.
+ if (!callsites->allPredecessorsKnown())
+ return reset(after);
+ for (Operation *callsite : callsites->getKnownPredecessors()) {
+ // Get the dense lattice before the callsite.
+ if (Operation *prev = callsite->getPrevNode())
+ join(after, *getLatticeFor(block, prev));
+ else
+ join(after, *getLatticeFor(block, callsite->getBlock()));
+ }
+ return;
+ }
+
+ // Check if we can reason about the control-flow.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp()))
+ return visitRegionBranchOperation(block, branch, after);
+
+ // Otherwise, we can't reason about the data-flow.
+ return reset(after);
+ }
+
+ // Join the state with the state after the block's predecessors.
+ for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+ it != e; ++it) {
+ // Skip control edges that aren't executable. Edge liveness is an Executable
+ // state on the CFGEdge program point (presumably computed by the dead code
+ // analysis loaded alongside this one).
+ Block *predecessor = *it;
+ if (!getOrCreateFor<Executable>(
+ block, getProgramPoint<CFGEdge>(predecessor, block))
+ ->isLive())
+ continue;
+
+ // Merge in the state from the predecessor's terminator.
+ join(after, *getLatticeFor(block, predecessor->getTerminator()));
+ }
+}
+
+// Join into `after` every state that flows into `point` via region control-flow.
+// `point` is either an entry block of one of `branch`'s regions or `branch`
+// itself. Its predecessors come from the PredecessorState attached to `point`.
+// A predecessor equal to `branch` contributes the state before `branch`; any
+// other predecessor (a region terminator) contributes the state after itself.
+void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
+ ProgramPoint point, RegionBranchOpInterface branch,
+ AbstractDenseLattice *after) {
+ // Get the terminator predecessors.
+ // NOTE(review): this is only checked by assert; in release builds a partially
+ // resolved predecessor set would be joined silently instead of resetting.
+ const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
+ assert(predecessors->allPredecessorsKnown() &&
+ "unexpected unresolved region successors");
+
+ for (Operation *op : predecessors->getKnownPredecessors()) {
+ const AbstractDenseLattice *before;
+ // If the predecessor is the parent, get the state before the parent.
+ if (op == branch) {
+ if (Operation *prev = op->getPrevNode())
+ before = getLatticeFor(point, prev);
+ else
+ before = getLatticeFor(point, op->getBlock());
+
+ // Otherwise, get the state after the terminator.
+ } else {
+ before = getLatticeFor(point, op);
+ }
+ join(after, *before);
+ }
+}
+
+const AbstractDenseLattice *
+AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
+                                             ProgramPoint point) {
+  // Fetch the lattice after `point`, then record that `dependent` must be
+  // re-visited whenever that lattice is updated.
+  AbstractDenseLattice *lattice = getLattice(point);
+  addDependency(lattice, dependent);
+  return lattice;
+}
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
new file mode 100644
index 000000000000..c1fdf82e4bc7
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: test_callsite
+// CHECK: operand #0
+// CHECK-NEXT: - a
+func.func private @single_callsite_fn(%ptr: memref<i32>) -> memref<i32> {
+ return {tag = "test_callsite"} %ptr : memref<i32>
+}
+
+func.func @test_callsite() {
+ %ptr = memref.alloc() : memref<i32>
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+ %0 = func.call @single_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+ return
+}
+
+// CHECK-LABEL: test_tag: test_return_site
+// CHECK: operand #0
+// CHECK-NEXT: - b
+func.func private @single_return_site_fn(%ptr: memref<i32>) -> memref<i32> {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "b"} : memref<i32>
+ return %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_multiple_callsites
+// CHECK: operand #0
+// CHECK-NEXT: write0
+// CHECK-NEXT: write1
+func.func @test_return_site(%ptr: memref<i32>) -> memref<i32> {
+ %0 = func.call @single_return_site_fn(%ptr) : (memref<i32>) -> memref<i32>
+ return {tag = "test_return_site"} %0 : memref<i32>
+}
+
+func.func private @multiple_callsite_fn(%ptr: memref<i32>) -> memref<i32> {
+ return {tag = "test_multiple_callsites"} %ptr : memref<i32>
+}
+
+func.func @test_multiple_callsites(%a: i32, %ptr: memref<i32>) -> memref<i32> {
+ memref.store %a, %ptr[] {tag_name = "write0"} : memref<i32>
+ %0 = func.call @multiple_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+ memref.store %a, %ptr[] {tag_name = "write1"} : memref<i32>
+ %1 = func.call @multiple_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+ return %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_multiple_return_sites
+// CHECK: operand #0
+// CHECK-NEXT: return0
+// CHECK-NEXT: return1
+func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
+ cf.cond_br %cond, ^a, ^b
+
+^a:
+ memref.store %a, %ptr[] {tag_name = "return0"} : memref<i32>
+ return %ptr : memref<i32>
+
+^b:
+ memref.store %a, %ptr[] {tag_name = "return1"} : memref<i32>
+ return %ptr : memref<i32>
+}
+
+func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
+ %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref<i32>) -> memref<i32>
+ return {tag = "test_multiple_return_sites"} %0 : memref<i32>
+}
\ No newline at end of file
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
new file mode 100644
index 000000000000..69fb7125f0c5
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: test_simple_mod
+// CHECK: operand #0
+// CHECK-NEXT: - a
+// CHECK: operand #1
+// CHECK-NEXT: - b
+func.func @test_simple_mod(%arg0: memref<i32>, %arg1: memref<i32>) -> (memref<i32>, memref<i32>) {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ memref.store %c0, %arg0[] {tag_name = "a"} : memref<i32>
+ memref.store %c1, %arg1[] {tag_name = "b"} : memref<i32>
+ return {tag = "test_simple_mod"} %arg0, %arg1 : memref<i32>, memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_simple_mod_overwrite_a
+// CHECK: operand #1
+// CHECK-NEXT: - a
+// CHECK-LABEL: test_tag: test_simple_mod_overwrite_b
+// CHECK: operand #0
+// CHECK-NEXT: - b
+func.func @test_simple_mod_overwrite(%arg0: memref<i32>) -> memref<i32> {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %arg0[] {tag = "test_simple_mod_overwrite_a", tag_name = "a"} : memref<i32>
+ %c1 = arith.constant 1 : i32
+ memref.store %c1, %arg0[] {tag_name = "b"} : memref<i32>
+ return {tag = "test_simple_mod_overwrite_b"} %arg0 : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_control_flow
+// CHECK: operand #0
+// CHECK-NEXT: - b
+// CHECK-NEXT: - a
+func.func @test_mod_control_flow(%cond: i1, %ptr: memref<i32>) -> memref<i32> {
+ cf.cond_br %cond, ^a, ^b
+
+^a:
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+ cf.br ^c
+
+^b:
+ %c1 = arith.constant 1 : i32
+ memref.store %c1, %ptr[] {tag_name = "b"} : memref<i32>
+ cf.br ^c
+
+^c:
+ return {tag = "test_mod_control_flow"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_dead_branch
+// CHECK: operand #0
+// CHECK-NEXT: - a
+func.func @test_mod_dead_branch(%arg: i32, %ptr: memref<i32>) -> memref<i32> {
+ %0 = arith.subi %arg, %arg : i32
+ %1 = arith.constant -1 : i32
+ %2 = arith.cmpi sgt, %0, %1 : i32
+ cf.cond_br %2, ^a, ^b
+
+^a:
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+ cf.br ^c
+
+^b:
+ %c1 = arith.constant 1 : i32
+ memref.store %c1, %ptr[] {tag_name = "b"} : memref<i32>
+ cf.br ^c
+
+^c:
+ return {tag = "test_mod_dead_branch"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_region_control_flow
+// CHECK: operand #0
+// CHECK-NEXT: then
+// CHECK-NEXT: else
+func.func @test_mod_region_control_flow(%cond: i1, %ptr: memref<i32>) -> memref<i32> {
+ scf.if %cond {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "then"}: memref<i32>
+ } else {
+ %c1 = arith.constant 1 : i32
+ memref.store %c1, %ptr[] {tag_name = "else"} : memref<i32>
+ }
+ return {tag = "test_mod_region_control_flow"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_dead_region
+// CHECK: operand #0
+// CHECK-NEXT: else
+func.func @test_mod_dead_region(%ptr: memref<i32>) -> memref<i32> {
+ %false = arith.constant false
+ scf.if %false {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "then"}: memref<i32>
+ } else {
+ %c1 = arith.constant 1 : i32
+ memref.store %c1, %ptr[] {tag_name = "else"} : memref<i32>
+ }
+ return {tag = "test_mod_dead_region"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: unknown_memory_effects_a
+// CHECK: operand #1
+// CHECK-NEXT: - a
+// CHECK-LABEL: test_tag: unknown_memory_effects_b
+// CHECK: operand #0
+// CHECK-NEXT: - <unknown>
+func.func @unknown_memory_effects(%ptr: memref<i32>) -> memref<i32> {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag = "unknown_memory_effects_a", tag_name = "a"} : memref<i32>
+ "test.unknown_effects"() : () -> ()
+ return {tag = "unknown_memory_effects_b"} %ptr : memref<i32>
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 8e9446c78bfc..80572f9501fe 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestAnalysis
TestSlice.cpp
DataFlow/TestDeadCodeAnalysis.cpp
+ DataFlow/TestDenseDataFlowAnalysis.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
new file mode 100644
index 000000000000..669e683aca1d
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -0,0 +1,278 @@
+//===- TestDenseDataFlowAnalysis.cpp - Test dense data-flow analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+/// Lattice value recording the single SSA value that a value forwards. The
+/// stored value is null by default and after joining conflicting states.
+class UnderlyingValue {
+public:
+  /// Without further information, a value is its own underlying value.
+  static UnderlyingValue getPessimisticValueState(Value value) {
+    return {value};
+  }
+
+  /// Build a state whose underlying value is `underlyingValue` (null if
+  /// omitted).
+  UnderlyingValue(Value underlyingValue = {})
+      : underlyingValue(underlyingValue) {}
+
+  /// Return the stored underlying value.
+  Value getUnderlyingValue() const { return underlyingValue; }
+
+  /// Merge two states: agreeing states are kept as-is, while conflicting
+  /// states collapse to a null underlying value.
+  static UnderlyingValue join(const UnderlyingValue &lhs,
+                              const UnderlyingValue &rhs) {
+    if (lhs == rhs)
+      return lhs;
+    return UnderlyingValue();
+  }
+
+  /// Two states are equal exactly when they hold the same underlying value.
+  bool operator==(const UnderlyingValue &rhs) const {
+    return underlyingValue == rhs.underlyingValue;
+  }
+
+  /// Print the stored underlying value.
+  void print(raw_ostream &os) const { os << underlyingValue; }
+
+private:
+  /// The forwarded value.
+  Value underlyingValue;
+};
+
+/// This lattice represents, for a given memory resource, the potential last
+/// operations that modified the resource.
+class LastModification : public AbstractDenseLattice {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
+
+  using AbstractDenseLattice::AbstractDenseLattice;
+
+  /// The lattice is always initialized.
+  bool isUninitialized() const override { return false; }
+
+  /// Initialize the lattice. Does nothing.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
+  /// last modifications of all memory resources are unknown.
+  ChangeResult reset() override {
+    if (lastMods.empty())
+      return ChangeResult::NoChange;
+    lastMods.clear();
+    return ChangeResult::Change;
+  }
+
+  /// The lattice is never at a fixpoint.
+  bool isAtFixpoint() const override { return false; }
+
+  /// Join the last modifications of `lattice` into this one. A change is
+  /// reported only if a resource entry was created or a resource gained a new
+  /// potential modifier. Comparing the two sets with `!=` is not sufficient:
+  /// `SetVector` equality is order-sensitive, and an existing superset never
+  /// compares equal to the incoming subset. Joins that add nothing would then
+  /// still report a change and needlessly re-trigger dependent program points.
+  ChangeResult join(const AbstractDenseLattice &lattice) override {
+    const auto &rhs = static_cast<const LastModification &>(lattice);
+    ChangeResult result = ChangeResult::NoChange;
+    for (const auto &mod : rhs.lastMods) {
+      auto entry = lastMods.try_emplace(mod.first);
+      auto &lhsMod = entry.first->second;
+      size_t numModsBefore = lhsMod.size();
+      lhsMod.insert(mod.second.begin(), mod.second.end());
+      if (entry.second || lhsMod.size() != numModsBefore)
+        result |= ChangeResult::Change;
+    }
+    return result;
+  }
+
+  /// Set `op` as the sole last modification of `value`. Reports a change unless
+  /// `op` already was the only recorded modifier.
+  ChangeResult set(Value value, Operation *op) {
+    auto &lastMod = lastMods[value];
+    ChangeResult result = ChangeResult::NoChange;
+    if (lastMod.size() != 1 || *lastMod.begin() != op) {
+      result = ChangeResult::Change;
+      lastMod.clear();
+      lastMod.insert(op);
+    }
+    return result;
+  }
+
+  /// Get the last modifications of a value. Returns none if the last
+  /// modifications are not known.
+  Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
+    auto it = lastMods.find(value);
+    if (it == lastMods.end())
+      return {};
+    return it->second.getArrayRef();
+  }
+
+  void print(raw_ostream &os) const override {
+    for (const auto &lastMod : lastMods) {
+      os << lastMod.first << ":\n";
+      for (Operation *op : lastMod.second)
+        os << " " << *op << "\n";
+    }
+  }
+
+private:
+  /// The potential last modifications of a memory resource. Use a set vector to
+  /// keep the results deterministic.
+  DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
+                            SmallPtrSet<Operation *, 2>>>
+      lastMods;
+};
+
+/// Dense analysis that tracks, for each memory resource (identified by its most
+/// underlying value), the operations that may have last modified it.
+class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
+public:
+  using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
+
+  /// Visit an operation. If the operation does not implement
+  /// MemoryEffectOpInterface, or has an effect that is not attached to a
+  /// value, the lattice is reset so that all last modifications are unknown.
+  /// Otherwise, the state before the operation is joined into the state after
+  /// it. For every effect other than a read (e.g. a write or an allocation),
+  /// the operation becomes the sole last modifier of the affected resource.
+  void visitOperation(Operation *op, const LastModification &before,
+                      LastModification *after) override;
+};
+
+/// Sparse lattice over `UnderlyingValue`. It is spelled out as its own class
+/// (rather than used as `Lattice<UnderlyingValue>` directly) so that it carries
+/// an explicit type ID.
+struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
+  using Lattice<UnderlyingValue>::Lattice;
+};
+
+/// Sparse analysis that forwards values along control-flow and callgraph edges
+/// to find a single underlying value for block arguments. It is here so that
+/// the test pass can exercise the dense analysis across the callgraph.
+class UnderlyingValueAnalysis
+    : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
+public:
+  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+  /// Operation results are not traced through: every result lattice is sent
+  /// straight to its pessimistic fixpoint (for `UnderlyingValue`, presumably
+  /// the result value itself).
+  void visitOperation(Operation *op,
+                      ArrayRef<const UnderlyingValueLattice *> operands,
+                      ArrayRef<UnderlyingValueLattice *> results) override {
+    markAllPessimisticFixpoint(results);
+  }
+};
+} // end anonymous namespace
+
+/// Follow the chain of underlying values starting at `value` until reaching a
+/// value that is its own underlying value, and return it. Returns a null value
+/// if any link in the chain has no lattice or an uninitialized one.
+/// NOTE(review): assumes the chain is acyclic; a cycle would not terminate.
+static Value getMostUnderlyingValue(
+    Value value,
+    function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
+  for (Value current = value;;) {
+    const UnderlyingValueLattice *lattice = getUnderlyingValueFn(current);
+    if (!lattice || lattice->isUninitialized())
+      return {};
+    Value next = lattice->getValue().getUnderlyingValue();
+    if (next == current)
+      return current;
+    current = next;
+  }
+}
+
+void LastModifiedAnalysis::visitOperation(Operation *op,
+                                          const LastModification &before,
+                                          LastModification *after) {
+  auto memory = dyn_cast<MemoryEffectOpInterface>(op);
+  // If we can't reason about the memory effects, then conservatively assume we
+  // can't deduce anything about the last modifications.
+  if (!memory)
+    return reset(after);
+
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memory.getEffects(effects);
+
+  // Start from the state before the operation. `after` is mutated in place
+  // from here on, so every exit path below must either reset the lattice or
+  // propagate `result`.
+  ChangeResult result = after->join(before);
+  for (const auto &effect : effects) {
+    Value value = effect.getValue();
+
+    // If we see an effect on anything other than a value, assume we can't
+    // deduce anything about the last modifications.
+    if (!value)
+      return reset(after);
+
+    value = getMostUnderlyingValue(value, [&](Value current) {
+      return getOrCreateFor<UnderlyingValueLattice>(op, current);
+    });
+    // The underlying value is not known yet; getOrCreateFor presumably makes
+    // `op` a dependent of that lattice, so it is revisited later. Still
+    // propagate whatever the join above changed: on the revisit the join would
+    // report no change and the update would otherwise never reach dependents.
+    if (!value)
+      return propagateIfChanged(after, result);
+
+    // Nothing to do for reads.
+    if (isa<MemoryEffects::Read>(effect.getEffect()))
+      continue;
+
+    result |= after->set(value, op);
+  }
+  propagateIfChanged(after, result);
+}
+
+namespace {
+/// Test pass (`test-last-modified`) that runs the dead code, sparse constant
+/// propagation, last-modified and underlying-value analyses together. For every
+/// operation with a `tag` string attribute, it prints to stderr the last
+/// modifiers of each operand's most underlying value. Modifiers are named by
+/// their `tag_name` attribute if present, otherwise by their operation name.
+struct TestLastModifiedPass
+ : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
+
+ StringRef getArgument() const override { return "test-last-modified"; }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ // Dead code and constant propagation supply liveness for the dense analysis;
+ // the underlying-value analysis resolves which memref each effect targets.
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<LastModifiedAnalysis>();
+ solver.load<UnderlyingValueAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+
+ raw_ostream &os = llvm::errs();
+
+ op->walk([&](Operation *op) {
+ auto tag = op->getAttrOfType<StringAttr>("tag");
+ if (!tag)
+ return;
+ os << "test_tag: " << tag.getValue() << ":\n";
+ // NOTE(review): dense lattices are only created for ops in live blocks, so
+ // a tagged op in dead code presumably has no state and trips this assert.
+ const LastModification *lastMods =
+ solver.lookupState<LastModification>(op);
+ assert(lastMods && "expected a dense lattice");
+ for (auto &it : llvm::enumerate(op->getOperands())) {
+ os << " operand #" << it.index() << "\n";
+ // Resolve the operand to the memory resource the lattice is keyed on.
+ Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
+ return solver.lookupState<UnderlyingValueLattice>(value);
+ });
+ assert(value && "expected an underlying value");
+ if (Optional<ArrayRef<Operation *>> lastMod =
+ lastMods->getLastModifiers(value)) {
+ for (Operation *lastModifier : *lastMod) {
+ if (auto tagName =
+ lastModifier->getAttrOfType<StringAttr>("tag_name")) {
+ os << " - " << tagName.getValue() << "\n";
+ } else {
+ os << " - " << lastModifier->getName() << "\n";
+ }
+ }
+ } else {
+ // No entry for this resource: its last modifications are unknown.
+ os << " - <unknown>\n";
+ }
+ }
+ });
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Make the `test-last-modified` pass available to mlir-opt. The temporary
+/// PassRegistration exists only for the effect of its constructor.
+void registerTestLastModifiedPass() {
+  PassRegistration<TestLastModifiedPass>{};
+}
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ecd85a7f3aaa..1b81fee9f86e 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@ void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
+void registerTestLastModifiedPass();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgFusionTransforms();
@@ -185,6 +186,7 @@ void registerTestPasses() {
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
+ mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgFusionTransforms();
More information about the Mlir-commits
mailing list