[Mlir-commits] [mlir] ead75d9 - (Reland)[mlir] Add a generic data-flow analysis framework
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 14 14:33:17 PDT 2022
Author: Mogball
Date: 2022-06-14T21:33:05Z
New Revision: ead75d9434ec83817a715582065349ff70932be6
URL: https://github.com/llvm/llvm-project/commit/ead75d9434ec83817a715582065349ff70932be6
DIFF: https://github.com/llvm/llvm-project/commit/ead75d9434ec83817a715582065349ff70932be6.diff
LOG: (Reland)[mlir] Add a generic data-flow analysis framework
Removes one element of the pointer union to make it work on 32-bit
systems.
This patch introduces a generic data-flow analysis framework to MLIR. The framework implements a fixed-point iteration algorithm and a dependency graph between lattice states and analyses. Lattice states and program points are fully extensible to support highly customizable analyses.
Reviewed By: phisiart, rriddle
Differential Revision: https://reviews.llvm.org/D126751
Added:
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/test/Analysis/test-foo-analysis.mlir
mlir/test/lib/Analysis/TestDataFlowFramework.cpp
Modified:
mlir/lib/Analysis/CMakeLists.txt
mlir/test/lib/Analysis/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
new file mode 100644
index 0000000000000..9b85182532232
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -0,0 +1,454 @@
+//===- DataFlowFramework.h - A generic framework for data-flow analysis ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a generic framework for writing data-flow analysis in MLIR.
+// The framework consists of a solver, which runs the fixed-point iteration and
+// manages analysis dependencies, and a data-flow analysis class used to
+// implement specific analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
+#define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/StorageUniquer.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/TypeName.h"
+#include <queue>
+
+namespace mlir {
+
+/// Forward declare the analysis state class.
+class AnalysisState;
+
+//===----------------------------------------------------------------------===//
+// GenericProgramPoint
+//===----------------------------------------------------------------------===//
+
+/// Abstract class for generic program points. In classical data-flow analysis,
+/// program points represent positions in a program to which lattice elements
+/// are attached. In sparse data-flow analysis, these can be SSA values, and in
+/// dense data-flow analysis, these are the program points before and after
+/// every operation.
+///
+/// In the general MLIR data-flow analysis framework, program points are an
+/// extensible concept. Program points are uniquely identifiable objects to
+/// which analysis states can be attached. The semantics of program points are
+/// defined by the analyses that specify their transfer functions.
+///
+/// Program points are implemented using MLIR's storage uniquer framework and
+/// type ID system to provide RTTI.
+class GenericProgramPoint : public StorageUniquer::BaseStorage {
+public:
+ /// Virtual destructor; defined out-of-line in DataFlowFramework.cpp.
+ virtual ~GenericProgramPoint();
+
+ /// Get the abstract program point's type identifier. This is the concrete
+ /// subclass's TypeID and is what `GenericProgramPointBase::classof` checks.
+ TypeID getTypeID() const { return typeID; }
+
+ /// Get a derived source location for the program point.
+ virtual Location getLoc() const = 0;
+
+ /// Print the program point.
+ virtual void print(raw_ostream &os) const = 0;
+
+protected:
+ /// Create an abstract program point with type identifier.
+ explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
+
+private:
+ /// The type identifier of the program point.
+ TypeID typeID;
+};
+
+//===----------------------------------------------------------------------===//
+// GenericProgramPointBase
+//===----------------------------------------------------------------------===//
+
+/// Base class for generic program points based on a concrete program point
+/// type and a content key. This class defines the common methods required for
+/// operability with the storage uniquer framework.
+///
+/// The provided key type uniquely identifies the concrete program point
+/// instance, and the key is stored as the class's only data member.
+///
+/// Note: inside this class, `Value` names the key template parameter; it
+/// shadows `mlir::Value`.
+template <typename ConcreteT, typename Value>
+class GenericProgramPointBase : public GenericProgramPoint {
+public:
+ /// The concrete key type used by the storage uniquer. This class is uniqued
+ /// by its contents.
+ using KeyTy = Value;
+ /// Alias for the base class.
+ using Base = GenericProgramPointBase<ConcreteT, Value>;
+
+ /// Construct an instance of the program point using the provided value and
+ /// the type ID of the concrete type.
+ template <typename ValueT>
+ explicit GenericProgramPointBase(ValueT &&value)
+ : GenericProgramPoint(TypeID::get<ConcreteT>()),
+ value(std::forward<ValueT>(value)) {}
+
+ /// Get a uniqued instance of this program point class with the given
+ /// arguments. No post-construction init function is passed to the uniquer.
+ template <typename... Args>
+ static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
+ return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
+ }
+
+ /// Allocate space for a program point from the uniquer's allocator and
+ /// construct it in-place with placement new.
+ template <typename ValueT>
+ static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
+ ValueT &&value) {
+ return new (alloc.allocate<ConcreteT>())
+ ConcreteT(std::forward<ValueT>(value));
+ }
+
+ /// Compare the stored key against a candidate key. Two program points are
+ /// equal if their values are equal.
+ bool operator==(const Value &value) const { return this->value == value; }
+
+ /// Provide LLVM-style RTTI using type IDs.
+ static bool classof(const GenericProgramPoint *point) {
+ return point->getTypeID() == TypeID::get<ConcreteT>();
+ }
+
+ /// Get the contents of the program point.
+ const Value &getValue() const { return value; }
+
+private:
+ /// The program point value (the uniquing key).
+ Value value;
+};
+
+//===----------------------------------------------------------------------===//
+// ProgramPoint
+//===----------------------------------------------------------------------===//
+
+/// Fundamental IR components are supported as first-class program points.
+/// A program point holds exactly one of: a custom `GenericProgramPoint`, an
+/// `Operation`, a `Value`, or a `Block`.
+///
+/// NOTE: `PointerUnion` encodes the active alternative in the low bits of the
+/// stored pointer, so the number of alternatives is limited by pointer
+/// alignment. The union was trimmed to these four to work on 32-bit hosts;
+/// check that limit before adding another alternative.
+struct ProgramPoint
+ : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
+ using ParentTy =
+ PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
+ /// Inherit constructors.
+ using ParentTy::PointerUnion;
+ /// Allow implicit conversion from the parent type. A default-constructed
+ /// program point is null.
+ ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
+
+ /// Print the program point.
+ void print(raw_ostream &os) const;
+
+ /// Get the source location of the program point.
+ Location getLoc() const;
+};
+
+/// Forward declaration of the data-flow analysis class.
+class DataFlowAnalysis;
+
+//===----------------------------------------------------------------------===//
+// DataFlowSolver
+//===----------------------------------------------------------------------===//
+
+/// The general data-flow analysis solver. This class is responsible for
+/// orchestrating child data-flow analyses, running the fixed-point iteration
+/// algorithm, managing analysis state and program point memory, and tracking
+/// dependencies between analyses, program points, and analysis states.
+///
+/// Steps to run a data-flow analysis:
+///
+/// 1. Load and initialize children analyses. Children analyses are instantiated
+/// in the solver and initialized, building their dependency relations.
+/// 2. Configure and run the analysis. The solver invokes the children analyses
+/// according to their dependency relations until a fixed point is reached.
+/// 3. Query analysis state results from the solver.
+///
+/// TODO: Optimize the internal implementation of the solver.
+class DataFlowSolver {
+public:
+ /// Load an analysis into the solver. Return the analysis instance. The
+ /// analysis is constructed with a reference to this solver followed by
+ /// `args`, and the solver owns it.
+ template <typename AnalysisT, typename... Args>
+ AnalysisT *load(Args &&...args);
+
+ /// Initialize the children analyses starting from the provided top-level
+ /// operation and run the analysis until fixpoint. Returns failure as soon
+ /// as any child analysis fails to initialize or to visit a program point.
+ LogicalResult initializeAndRun(Operation *top);
+
+ /// Look up an analysis state for the given program point. Returns null if one
+ /// does not exist. This never creates a state and records no dependency.
+ template <typename StateT, typename PointT>
+ const StateT *lookupState(PointT point) const {
+ auto it = analysisStates.find({point, TypeID::get<StateT>()});
+ if (it == analysisStates.end())
+ return nullptr;
+ return static_cast<const StateT *>(it->second.get());
+ }
+
+ /// Get a uniqued program point instance. If one is not present, it is
+ /// created with the provided arguments.
+ template <typename PointT, typename... Args>
+ PointT *getProgramPoint(Args &&...args) {
+ return PointT::get(uniquer, std::forward<Args>(args)...);
+ }
+
+ /// A work item on the solver queue is a program point, child analysis pair.
+ /// Each item is processed by invoking the child analysis at the program
+ /// point.
+ using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>;
+ /// Push a work item onto the back of the worklist (no deduplication).
+ void enqueue(WorkItem item) { worklist.push(std::move(item)); }
+
+protected:
+ /// Get the state associated with the given program point. If it does not
+ /// exist, create an uninitialized state.
+ template <typename StateT, typename PointT>
+ StateT *getOrCreateState(PointT point);
+
+ /// Propagate an update to an analysis state if it changed by pushing
+ /// dependent work items to the back of the queue.
+ void propagateIfChanged(AnalysisState *state, ChangeResult changed);
+
+ /// Add a dependency to an analysis state on a child analysis and program
+ /// point. If the state is updated, the child analysis must be invoked on the
+ /// given program point again.
+ void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
+ ProgramPoint point);
+
+private:
+ /// The solver's work queue, processed in FIFO order: `enqueue` pushes to the
+ /// back and the solver pops from the front. Items are neither prioritized
+ /// nor deduplicated.
+ std::queue<WorkItem> worklist;
+
+ /// Type-erased instances of the children analyses.
+ SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
+
+ /// The storage uniquer instance that owns the memory of the allocated program
+ /// points.
+ StorageUniquer uniquer;
+
+ /// A type-erased map of program points to associated analysis states for
+ /// first-class program points.
+ ///
+ /// NOTE: members are destroyed in reverse declaration order, so these states
+ /// (which hold `ProgramPoint`s that may point into `uniquer`) are destroyed
+ /// before `uniquer`. Keep this member declared after `uniquer`.
+ DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
+ analysisStates;
+
+ /// Allow the base child analysis class to access the internals of the solver.
+ friend class DataFlowAnalysis;
+};
+
+//===----------------------------------------------------------------------===//
+// AnalysisState
+//===----------------------------------------------------------------------===//
+
+/// Base class for generic analysis states. Analysis states contain data-flow
+/// information that is attached to program points and which evolves as the
+/// analysis iterates.
+///
+/// This class places no restrictions on the semantics of analysis states beyond
+/// the following requirements:
+///
+/// 1. Querying the state of a program point prior to visiting that point
+/// results in uninitialized state. Analyses must be aware of uninitialized
+/// states.
+/// 2. Analysis states can reach fixpoints, where subsequent updates will never
+/// trigger a change in the state.
+/// 3. Analysis states that are uninitialized can be forcefully initialized to a
+/// default value.
+class AnalysisState {
+public:
+ virtual ~AnalysisState();
+
+ /// Create the analysis state at the given program point.
+ AnalysisState(ProgramPoint point) : point(point) {}
+
+ /// Returns true if the analysis state is uninitialized.
+ virtual bool isUninitialized() const = 0;
+
+ /// Force an uninitialized analysis state to initialize itself with a default
+ /// value. The solver calls this only on states that are still
+ /// uninitialized once its worklist is empty.
+ virtual ChangeResult defaultInitialize() = 0;
+
+ /// Print the contents of the analysis state.
+ virtual void print(raw_ostream &os) const = 0;
+
+protected:
+ /// This function is called by the solver when the analysis state is updated
+ /// to optionally enqueue more work items. For example, if a state tracks
+ /// dependents through the IR (e.g. use-def chains), this function can be
+ /// implemented to push those dependents on the worklist.
+ virtual void onUpdate(DataFlowSolver *solver) const {}
+
+ /// The dependency relations originating from this analysis state. An entry
+ /// `state -> (analysis, point)` is created when `analysis` queries `state`
+ /// when updating `point`.
+ ///
+ /// When this state is updated, all dependent child analysis invocations are
+ /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
+ /// deterministic.
+ ///
+ /// Store the dependents on the analysis state for efficiency.
+ SetVector<DataFlowSolver::WorkItem> dependents;
+
+ /// The program point to which the state belongs.
+ ProgramPoint point;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ /// When compiling with debugging, keep a name for the analysis state.
+ StringRef debugName;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+ /// Allow the framework to access the dependents.
+ friend class DataFlowSolver;
+};
+
+//===----------------------------------------------------------------------===//
+// DataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for all data-flow analyses. A child analysis is expected to build
+/// an initial dependency graph (and optionally provide an initial state) when
+/// initialized and define transfer functions when visiting program points.
+///
+/// In classical data-flow analysis, the dependency graph is fixed and analyses
+/// define explicit transfer functions between input states and output states.
+/// In this framework, however, the dependency graph can change during the
+/// analysis, and transfer functions are opaque such that the solver doesn't
+/// know which states will be updated by calling `visit` on an analysis. This
+/// allows multiple analyses to plug in and provide values for the same state.
+///
+/// Generally, when an analysis queries an uninitialized state, it is expected
+/// to "bail out", i.e., not provide any updates. When the value is initialized,
+/// the solver will re-invoke the analysis. If the solver exhausts its worklist,
+/// however, and there are still uninitialized states, the solver "nudges" the
+/// analyses by default-initializing those states.
+class DataFlowAnalysis {
+public:
+ virtual ~DataFlowAnalysis();
+
+ /// Create an analysis with a reference to the parent solver.
+ explicit DataFlowAnalysis(DataFlowSolver &solver);
+
+ /// Initialize the analysis from the provided top-level operation by building
+ /// an initial dependency graph between all program points of interest. This
+ /// can be implemented by calling `visit` on all program points of interest
+ /// below the top-level operation.
+ ///
+ /// An analysis can optionally provide initial values to certain analysis
+ /// states to influence the evolution of the analysis.
+ virtual LogicalResult initialize(Operation *top) = 0;
+
+ /// Visit the given program point. This function is invoked by the solver on
+ /// this analysis with a given program point when a dependent analysis state
+ /// is updated. The function is similar to a transfer function; it queries
+ /// certain analysis states and sets other states.
+ ///
+ /// The function is expected to create dependencies on queried states and
+ /// propagate updates on changed states. A dependency can be created by
+ /// calling `addDependency` between the input state and a program point,
+ /// indicating that, if the state is updated, the solver should invoke `visit`
+ /// on the program point. The dependent point does not have to be the same as
+ /// the provided point. An update to a state is propagated by calling
+ /// `propagateIfChanged` on the state. If the state has changed, then all its
+ /// dependents are placed on the worklist.
+ ///
+ /// The dependency graph does not need to be static. Each invocation of
+ /// `visit` can add new dependencies, but these dependencies will not be
+ /// dynamically added to the worklist because the solver doesn't know what
+ /// will provide a value for them.
+ virtual LogicalResult visit(ProgramPoint point) = 0;
+
+protected:
+ /// Create a dependency between the given analysis state and program point
+ /// on this analysis.
+ void addDependency(AnalysisState *state, ProgramPoint point);
+
+ /// Propagate an update to a state if it changed.
+ void propagateIfChanged(AnalysisState *state, ChangeResult changed);
+
+ /// Register a custom program point class.
+ template <typename PointT>
+ void registerPointKind() {
+ solver.uniquer.registerParametricStorageType<PointT>();
+ }
+
+ /// Get or create a custom program point.
+ template <typename PointT, typename... Args>
+ PointT *getProgramPoint(Args &&...args) {
+ return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
+ }
+
+ /// Get the analysis state associated with the program point. The returned
+ /// state is expected to be "write-only", and any updates need to be
+ /// propagated by `propagateIfChanged`.
+ template <typename StateT, typename PointT>
+ StateT *getOrCreate(PointT point) {
+ return solver.getOrCreateState<StateT>(point);
+ }
+
+ /// Get a read-only analysis state for the given point and create a dependency
+ /// on `dependent`. If the returned state is updated elsewhere, this analysis
+ /// is re-invoked on the dependent.
+ template <typename StateT, typename PointT>
+ const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
+ StateT *state = getOrCreate<StateT>(point);
+ addDependency(state, dependent);
+ return state;
+ }
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ /// When compiling with debugging, keep a name for the analysis.
+ StringRef debugName;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+private:
+ /// The parent data-flow solver.
+ DataFlowSolver &solver;
+
+ /// Allow the data-flow solver to access the internals of this class.
+ friend class DataFlowSolver;
+};
+
+/// Construct an `AnalysisT` (passing this solver followed by `args` to its
+/// constructor), transfer ownership of it to the solver, and return a
+/// non-owning pointer to it.
+template <typename AnalysisT, typename... Args>
+AnalysisT *DataFlowSolver::load(Args &&...args) {
+  // Give the new analysis an owner *before* handing it to the container, so
+  // it cannot leak if growing `childAnalyses` fails. Plain `new` (rather than
+  // `std::make_unique`) keeps construction access-checked from inside the
+  // solver, exactly as before.
+  std::unique_ptr<DataFlowAnalysis> analysis(
+      new AnalysisT(*this, std::forward<Args>(args)...));
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  analysis->debugName = llvm::getTypeName<AnalysisT>();
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  auto *result = static_cast<AnalysisT *>(analysis.get());
+  childAnalyses.push_back(std::move(analysis));
+  return result;
+}
+
+/// Return the `StateT` attached to `point`, lazily creating an uninitialized
+/// one on first request. The solver keeps ownership of the state.
+template <typename StateT, typename PointT>
+StateT *DataFlowSolver::getOrCreateState(PointT point) {
+  std::unique_ptr<AnalysisState> &slot =
+      analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
+  if (slot)
+    return static_cast<StateT *>(slot.get());
+
+  // First request for this (point, state kind) pair: materialize it.
+  slot.reset(new StateT(point));
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  slot->debugName = llvm::getTypeName<StateT>();
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  return static_cast<StateT *>(slot.get());
+}
+
+/// Stream an analysis state by delegating to `AnalysisState::print`.
+inline raw_ostream &operator<<(raw_ostream &stream,
+                               const AnalysisState &analysisState) {
+  analysisState.print(stream);
+  return stream;
+}
+
+/// Stream a program point by delegating to `ProgramPoint::print`.
+inline raw_ostream &operator<<(raw_ostream &stream, ProgramPoint programPoint) {
+  programPoint.print(stream);
+  return stream;
+}
+
+} // end namespace mlir
+
+namespace llvm {
+/// Allow hashing of program points. `ProgramPoint` adds no data members to its
+/// `PointerUnion` base, so it reuses the base's empty/tombstone keys, hash, and
+/// equality.
+template <>
+struct DenseMapInfo<mlir::ProgramPoint>
+ : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
+} // end namespace llvm
+
+#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 6c45e40efa9ab..fe5f3832322ce 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(MLIRAnalysis
BufferViewFlowAnalysis.cpp
CallGraph.cpp
DataFlowAnalysis.cpp
+ DataFlowFramework.cpp
DataLayoutAnalysis.cpp
IntRangeAnalysis.cpp
Liveness.cpp
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
new file mode 100644
index 0000000000000..be18432468d4f
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -0,0 +1,151 @@
+//===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "dataflow"
+// The `debugName` members of analyses and analysis states exist only when
+// LLVM_ENABLE_ABI_BREAKING_CHECKS is set, so debug output that prints them is
+// compiled out entirely otherwise.
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
+#else
+#define DATAFLOW_DEBUG(X)
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// GenericProgramPoint
+//===----------------------------------------------------------------------===//
+
+// Out-of-line definition of the virtual destructor declared in the header
+// (presumably so the vtable is anchored in this translation unit).
+GenericProgramPoint::~GenericProgramPoint() = default;
+
+//===----------------------------------------------------------------------===//
+// AnalysisState
+//===----------------------------------------------------------------------===//
+
+// Out-of-line definition of the virtual destructor declared in the header
+// (presumably so the vtable is anchored in this translation unit).
+AnalysisState::~AnalysisState() = default;
+
+//===----------------------------------------------------------------------===//
+// ProgramPoint
+//===----------------------------------------------------------------------===//
+
+/// Print the program point by dispatching on the alternative it holds. A null
+/// point prints as "<NULL POINT>".
+void ProgramPoint::print(raw_ostream &os) const {
+  if (isNull())
+    os << "<NULL POINT>";
+  else if (auto *genericPoint = dyn_cast<GenericProgramPoint *>())
+    genericPoint->print(os);
+  else if (auto *operation = dyn_cast<Operation *>())
+    operation->print(os);
+  else if (auto ssaValue = dyn_cast<Value>())
+    ssaValue.print(os);
+  else
+    get<Block *>()->print(os);
+}
+
+/// Return a source location for the program point. For a block, the location
+/// of its parent region is used. Unlike `print`, a null point is not handled.
+Location ProgramPoint::getLoc() const {
+  if (auto *genericPoint = dyn_cast<GenericProgramPoint *>())
+    return genericPoint->getLoc();
+  if (auto *operation = dyn_cast<Operation *>())
+    return operation->getLoc();
+  if (auto ssaValue = dyn_cast<Value>())
+    return ssaValue.getLoc();
+  Region *parentRegion = get<Block *>()->getParent();
+  return parentRegion->getLoc();
+}
+
+//===----------------------------------------------------------------------===//
+// DataFlowSolver
+//===----------------------------------------------------------------------===//
+
+/// Prime every loaded analysis on `top`, then alternate between draining the
+/// worklist and default-initializing ("nudging") states that are still
+/// uninitialized, until a nudge round schedules no further work. Fails as soon
+/// as any analysis fails to initialize or to visit a program point.
+LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
+  for (std::unique_ptr<DataFlowAnalysis> &child : childAnalyses) {
+    DATAFLOW_DEBUG(llvm::dbgs()
+                   << "Priming analysis: " << child->debugName << "\n");
+    if (failed(child->initialize(top)))
+      return failure();
+  }
+
+  while (true) {
+    // Process queued (program point, analysis) pairs in FIFO order.
+    while (!worklist.empty()) {
+      WorkItem item = worklist.front();
+      worklist.pop();
+      ProgramPoint point = item.first;
+      DataFlowAnalysis *analysis = item.second;
+
+      DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
+                                  << "' on: " << point << "\n");
+      if (failed(analysis->visit(point)))
+        return failure();
+    }
+
+    // With the worklist empty, no pending work item can update the remaining
+    // uninitialized states, so force them to their defaults. The order does
+    // not matter for that reason.
+    for (auto &entry : analysisStates) {
+      AnalysisState &state = *entry.second;
+      if (state.isUninitialized()) {
+        DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing "
+                                    << state.debugName << " of " << state.point
+                                    << "\n");
+        propagateIfChanged(&state, state.defaultInitialize());
+      }
+    }
+
+    // Fixpoint: default initialization scheduled no new work.
+    if (worklist.empty())
+      return success();
+  }
+}
+
+/// If `changed` reports a change, schedule every work item that depends on
+/// `state`, then let the state enqueue extra work through `onUpdate`.
+void DataFlowSolver::propagateIfChanged(AnalysisState *state,
+                                        ChangeResult changed) {
+  if (changed != ChangeResult::Change)
+    return;
+
+  DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
+                              << " of " << state->point << "\n"
+                              << "Value: " << *state << "\n");
+  for (const WorkItem &dependent : state->dependents)
+    enqueue(dependent);
+  state->onUpdate(this);
+}
+
+/// Record that `analysis` must revisit `point` whenever `state` changes.
+/// Repeated registrations of the same pair are ignored by the `SetVector`.
+void DataFlowSolver::addDependency(AnalysisState *state,
+                                   DataFlowAnalysis *analysis,
+                                   ProgramPoint point) {
+  bool isNewDependency = state->dependents.insert(WorkItem(point, analysis));
+  (void)isNewDependency;
+  DATAFLOW_DEBUG({
+    if (isNewDependency) {
+      llvm::dbgs() << "Creating dependency between " << state->debugName
+                   << " of " << state->point << "\nand " << analysis->debugName
+                   << " on " << point << "\n";
+    }
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// DataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+DataFlowAnalysis::~DataFlowAnalysis() = default;
+
+/// Bind the analysis to the solver that owns and drives it.
+DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &parentSolver)
+    : solver(parentSolver) {}
+
+/// Forward to the solver, naming this analysis as the one to re-invoke on
+/// `dependent` whenever `dependee` changes.
+void DataFlowAnalysis::addDependency(AnalysisState *dependee,
+                                     ProgramPoint dependent) {
+  solver.addDependency(dependee, /*analysis=*/this, dependent);
+}
+
+/// Forward a possibly-changed state to the solver for propagation.
+void DataFlowAnalysis::propagateIfChanged(AnalysisState *updated,
+                                          ChangeResult changed) {
+  solver.propagateIfChanged(updated, changed);
+}
diff --git a/mlir/test/Analysis/test-foo-analysis.mlir b/mlir/test/Analysis/test-foo-analysis.mlir
new file mode 100644
index 0000000000000..7c5d07396a83f
--- /dev/null
+++ b/mlir/test/Analysis/test-foo-analysis.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt -split-input-file -pass-pipeline='func.func(test-foo-analysis)' %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: function: @test_default_init
+// No op sets `foo`, so `a` sees the entry block's initial value of 0.
+func.func @test_default_init() -> () {
+ // CHECK: a -> 0
+ "test.foo"() {tag = "a"} : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: function: @test_one_join
+// `b` XORs its `foo` value (1) into the incoming 0.
+func.func @test_one_join() -> () {
+ // CHECK: a -> 0
+ "test.foo"() {tag = "a"} : () -> ()
+ // CHECK: b -> 1
+ "test.foo"() {tag = "b", foo = 1 : ui64} : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: function: @test_two_join
+// `c` XORs 1 into the incoming 1, giving 0.
+func.func @test_two_join() -> () {
+ // CHECK: a -> 0
+ "test.foo"() {tag = "a"} : () -> ()
+ // CHECK: b -> 1
+ "test.foo"() {tag = "b", foo = 1 : ui64} : () -> ()
+ // CHECK: c -> 0
+ "test.foo"() {tag = "c", foo = 1 : ui64} : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: function: @test_fork
+// Both successors of `init` start from 1 (a = 1^2 = 3, b = 1^4 = 5); ^bb2
+// joins its two predecessors by XOR (3^5 = 6).
+func.func @test_fork() -> () {
+ // CHECK: init -> 1
+ "test.branch"() [^bb0, ^bb1] {tag = "init", foo = 1 : ui64} : () -> ()
+
+^bb0:
+ // CHECK: a -> 3
+ "test.branch"() [^bb2] {tag = "a", foo = 2 : ui64} : () -> ()
+
+^bb1:
+ // CHECK: b -> 5
+ "test.branch"() [^bb2] {tag = "b", foo = 4 : ui64} : () -> ()
+
+^bb2:
+ // CHECK: end -> 6
+ "test.foo"() {tag = "end"} : () -> ()
+ return
+
+}
+
+// -----
+
+// CHECK-LABEL: function: @test_simple_loop
+// ^bb0 branches back to itself, so its state is re-joined each time its own
+// terminator's state changes; the CHECKs pin the values the solver settles on.
+func.func @test_simple_loop() -> () {
+ // CHECK: init -> 1
+ "test.branch"() [^bb0] {tag = "init", foo = 1 : ui64} : () -> ()
+
+^bb0:
+ // CHECK: a -> 1
+ "test.foo"() {tag = "a", foo = 3 : ui64} : () -> ()
+ "test.branch"() [^bb0, ^bb1] : () -> ()
+
+^bb1:
+ // CHECK: end -> 3
+ "test.foo"() {tag = "end"} : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: function: @test_double_loop
+// Two back edges reach ^bb0 (from ^bb0 itself and from ^bb1); the CHECKs pin
+// the values the solver settles on.
+func.func @test_double_loop() -> () {
+ // CHECK: init -> 2
+ "test.branch"() [^bb0] {tag = "init", foo = 2 : ui64} : () -> ()
+
+^bb0:
+ // CHECK: a -> 1
+ "test.foo"() {tag = "a", foo = 3 : ui64} : () -> ()
+ "test.branch"() [^bb0, ^bb1] : () -> ()
+
+^bb1:
+ // CHECK: b -> 4
+ "test.foo"() {tag = "b", foo = 5 : ui64} : () -> ()
+ "test.branch"() [^bb0, ^bb2] : () -> ()
+
+^bb2:
+ // CHECK: end -> 4
+ "test.foo"() {tag = "end"} : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 02ca3a1481993..d0b9d2be4f6ea 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestAnalysis
TestAliasAnalysis.cpp
TestCallGraph.cpp
TestDataFlow.cpp
+ TestDataFlowFramework.cpp
TestLiveness.cpp
TestMatchReduction.cpp
TestMemRefBoundCheck.cpp
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
new file mode 100644
index 0000000000000..329be3c5446f9
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -0,0 +1,188 @@
+//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// An analysis state holding an optional 64-bit integer. Values are combined
+/// with XOR; an empty optional means the state is uninitialized.
+class FooState : public AnalysisState {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
+
+  using AnalysisState::AnalysisState;
+
+  /// Default-initialize the state to zero.
+  ChangeResult defaultInitialize() override { return join(0); }
+
+  /// The state is uninitialized until it holds a value.
+  bool isUninitialized() const override { return !state; }
+
+  /// Print the integer, or "none" while uninitialized.
+  void print(raw_ostream &os) const override {
+    if (!state) {
+      os << "none";
+      return;
+    }
+    os << *state;
+  }
+
+  /// Join with another state. An uninitialized `rhs` contributes nothing;
+  /// otherwise its value is joined in as an integer.
+  ChangeResult join(const FooState &rhs) {
+    return rhs.isUninitialized() ? ChangeResult::NoChange : join(*rhs.state);
+  }
+
+  /// Join with an integer. An uninitialized state adopts `value`; an
+  /// initialized one is XOR'd with it, which changes it only if `value != 0`.
+  ChangeResult join(uint64_t value) {
+    if (isUninitialized()) {
+      state = value;
+      return ChangeResult::Change;
+    }
+    uint64_t xored = *state ^ value;
+    if (xored == *state)
+      return ChangeResult::NoChange;
+    state = xored;
+    return ChangeResult::Change;
+  }
+
+  /// Overwrite this state with `rhs`, including its uninitialized-ness.
+  ChangeResult set(const FooState &rhs) {
+    bool unchanged = state == rhs.state;
+    if (!unchanged)
+      state = rhs.state;
+    return unchanged ? ChangeResult::NoChange : ChangeResult::Change;
+  }
+
+  /// Return the held integer; the state must be initialized.
+  uint64_t getValue() const { return *state; }
+
+private:
+  /// The integer value, or `None` while uninitialized.
+  Optional<uint64_t> state;
+};
+
+/// This analysis computes `FooState` across operations and control-flow edges.
+/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
+/// with the value before the operation.
+class FooAnalysis : public DataFlowAnalysis {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
+
+ using DataFlowAnalysis::DataFlowAnalysis;
+
+ /// Seed the entry block of `top`'s single region with 0, then visit every
+ /// block and op in that region. Fails if `top` does not have exactly one
+ /// region or contains an op with regions.
+ LogicalResult initialize(Operation *top) override;
+ /// Revisit an operation or block; any other program point kind is an error.
+ LogicalResult visit(ProgramPoint point) override;
+
+private:
+ /// Join the states at the predecessors' terminators into the block's state
+ /// (skipped for the entry block).
+ void visitBlock(Block *block);
+ /// Copy the incoming state (previous op, or the block for the first op) and
+ /// XOR in the op's `foo` attribute if present.
+ void visitOperation(Operation *op);
+};
+
+/// Test pass (`test-foo-analysis`) that runs `FooAnalysis` on a function and
+/// prints, to stderr, the final state of every op with a `tag` attribute.
+struct TestFooAnalysisPass
+ : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
+
+ StringRef getArgument() const override { return "test-foo-analysis"; }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Seed the analysis on `top`: give the entry block of its single region the
+/// initial value 0, then visit every block and operation to build the initial
+/// dependency graph. Ops with nested regions are rejected.
+LogicalResult FooAnalysis::initialize(Operation *top) {
+  if (top->getNumRegions() != 1)
+    return top->emitError("expected a single region top-level op");
+
+  Region &body = top->getRegion(0);
+  // A body-less op (e.g. an external `func.func` declaration) has an empty
+  // region. There is nothing to analyze, and `front()` would be invalid.
+  if (body.empty())
+    return success();
+
+  // Initialize the top-level state.
+  getOrCreate<FooState>(&body.front())->join(0);
+
+  // Visit all nested blocks and operations.
+  for (Block &block : body) {
+    visitBlock(&block);
+    for (Operation &op : block) {
+      if (op.getNumRegions())
+        return op.emitError("unexpected op with regions");
+      visitOperation(&op);
+    }
+  }
+  return success();
+}
+
+/// Dispatch a solver callback to the operation or block handler. Any other
+/// kind of program point is reported as an error.
+LogicalResult FooAnalysis::visit(ProgramPoint point) {
+  if (Operation *op = point.dyn_cast<Operation *>())
+    visitOperation(op);
+  else if (Block *block = point.dyn_cast<Block *>())
+    visitBlock(block);
+  else
+    return emitError(point.getLoc(), "unknown point kind");
+  return success();
+}
+
+/// Join the states at the terminators of all predecessors into the block's
+/// state, registering a dependency on each so the block is revisited when one
+/// changes.
+void FooAnalysis::visitBlock(Block *block) {
+  // The entry block carries the initial state; there is nothing to join.
+  if (block->isEntryBlock())
+    return;
+
+  FooState *blockState = getOrCreate<FooState>(block);
+  ChangeResult changed = ChangeResult::NoChange;
+  for (Block *predecessor : block->getPredecessors()) {
+    Operation *terminator = predecessor->getTerminator();
+    changed |= blockState->join(*getOrCreateFor<FooState>(block, terminator));
+  }
+  propagateIfChanged(blockState, changed);
+}
+
+/// Compute the state after `op`: copy the incoming state (the previous op's,
+/// or the block's for the first op), then XOR in the `foo` attribute if any.
+void FooAnalysis::visitOperation(Operation *op) {
+  FooState *opState = getOrCreate<FooState>(op);
+
+  Operation *prev = op->getPrevNode();
+  const FooState *incoming = prev
+                                 ? getOrCreateFor<FooState>(op, prev)
+                                 : getOrCreateFor<FooState>(op, op->getBlock());
+  ChangeResult changed = opState->set(*incoming);
+
+  if (IntegerAttr fooAttr = op->getAttrOfType<IntegerAttr>("foo"))
+    changed |= opState->join(fooAttr.getUInt());
+  propagateIfChanged(opState, changed);
+}
+
+/// Run `FooAnalysis` over the function and report, on stderr, the final state
+/// of every operation that carries a `tag` string attribute.
+void TestFooAnalysisPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+
+  DataFlowSolver solver;
+  solver.load<FooAnalysis>();
+  if (failed(solver.initializeAndRun(func))) {
+    signalPassFailure();
+    return;
+  }
+
+  raw_ostream &out = llvm::errs();
+  out << "function: @" << func.getSymName() << "\n";
+  func.walk([&](Operation *op) {
+    if (auto tag = op->getAttrOfType<StringAttr>("tag")) {
+      const FooState *fooState = solver.lookupState<FooState>(op);
+      assert(fooState && !fooState->isUninitialized());
+      out << tag.getValue() << " -> " << fooState->getValue() << "\n";
+    }
+  });
+}
+
+namespace mlir {
+namespace test {
+/// Register the `test-foo-analysis` pass; called from mlir-opt's test pass
+/// registration.
+void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b50cfa964290f..9e872ab63f5f4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@ void registerTestDiagnosticsPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandMathPass();
+void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIntRangeInference();
@@ -175,6 +176,7 @@ void registerTestPasses() {
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestExpandMathPass();
+ mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIntRangeInference();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ffb3d5e0b52cf..8585b7792b081 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5787,17 +5787,12 @@ cc_library(
"lib/Analysis/*/*.cpp",
"lib/Analysis/*/*.h",
],
- exclude = [
- "lib/Analysis/Vector*.cpp",
- "lib/Analysis/Vector*.h",
- ],
),
hdrs = glob(
[
"include/mlir/Analysis/*.h",
"include/mlir/Analysis/*/*.h",
],
- exclude = ["include/mlir/Analysis/Vector*.h"],
),
includes = ["include"],
deps = [
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 2f5da7e9fee65..fa401a56a2552 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -26,6 +26,7 @@ cc_library(
"//mlir:AffineAnalysis",
"//mlir:AffineDialect",
"//mlir:Analysis",
+ "//mlir:FuncDialect",
"//mlir:IR",
"//mlir:MemRefDialect",
"//mlir:Pass",
More information about the Mlir-commits
mailing list