[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 07:17:29 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-linalg
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
This patch adds speculation behavior for linalg structured ops, allowing them to be hoisted out of loops using LICM.
---
Patch is 50.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108032.diff
17 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h (+8-8)
- (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+18-19)
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+1-1)
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+4-4)
- (modified) mlir/include/mlir/Analysis/DataFlowFramework.h (+101-135)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1)
- (modified) mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (+15-17)
- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+14-8)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+1-1)
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+16-9)
- (modified) mlir/lib/Analysis/DataFlowFramework.cpp (+16-29)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+31)
- (modified) mlir/test/Transforms/loop-invariant-code-motion.mlir (+45)
- (modified) mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp (+1-1)
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+1-1)
- (modified) mlir/test/lib/Analysis/TestDataFlowFramework.cpp (+8-4)
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+4-1)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 80c8b86c63678a..10ef8b6ba5843a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -35,21 +35,21 @@ namespace dataflow {
//===----------------------------------------------------------------------===//
/// This is a simple analysis state that represents whether the associated
-/// lattice anchor (either a block or a control-flow edge) is live.
+/// program point (either a block or a control-flow edge) is live.
class Executable : public AnalysisState {
public:
using AnalysisState::AnalysisState;
- /// Set the state of the lattice anchor to live.
+ /// Set the state of the program point to live.
ChangeResult setToLive();
- /// Get whether the lattice anchor is live.
+ /// Get whether the program point is live.
bool isLive() const { return live; }
/// Print the liveness.
void print(raw_ostream &os) const override;
- /// When the state of the lattice anchor is changed to live, re-invoke
+ /// When the state of the program point is changed to live, re-invoke
/// subscribed analyses on the operations in the block and on the block
/// itself.
void onUpdate(DataFlowSolver *solver) const override;
@@ -60,8 +60,8 @@ class Executable : public AnalysisState {
}
private:
- /// Whether the lattice anchor is live. Optimistically assume that the lattice
- /// anchor is dead.
+ /// Whether the program point is live. Optimistically assume that the program
+ /// point is dead.
bool live = false;
/// A set of analyses that should be updated when this state changes.
@@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState {
// CFGEdge
//===----------------------------------------------------------------------===//
-/// This lattice anchor represents a control-flow edge between a block and one
+/// This program point represents a control-flow edge between a block and one
/// of its successors.
class CFGEdge
- : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> {
+ : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
public:
using Base::Base;
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 7917f1e3ba6485..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -91,16 +91,15 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;
- /// Get the dense lattice after the execution of the given lattice anchor.
- virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+ /// Get the dense lattice after the execution of the given program point.
+ virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
/// Get the dense lattice after the execution of the given program point and
- /// add it as a dependency to a lattice anchor. That is, every time the
- /// lattice after anchor is updated, the dependent program point must be
- /// visited, and the newly triggered visit might update the lattice after
- /// dependent.
+ /// add it as a dependency to a program point. That is, every time the lattice
+ /// after point is updated, the dependent program point must be visited, and
+ /// the newly triggered visit might update the lattice after dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor);
+ ProgramPoint point);
/// Set the dense lattice at control flow entry point and propagate an update
/// if it changed.
@@ -250,9 +249,9 @@ class DenseForwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice on this lattice anchor.
- LatticeT *getLattice(LatticeAnchor anchor) override {
- return getOrCreate<LatticeT>(anchor);
+ /// Get the dense lattice after this program point.
+ LatticeT *getLattice(ProgramPoint point) override {
+ return getOrCreate<LatticeT>(point);
}
/// Set the dense lattice at control flow entry point and propagate an update
@@ -332,16 +331,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &after,
AbstractDenseLattice *before) = 0;
- /// Get the dense lattice before the execution of the lattice anchor. That is,
+ /// Get the dense lattice before the execution of the program point. That is,
/// before the execution of the given operation or after the execution of the
/// block.
- virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+ virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
- /// Get the dense lattice before the execution of the program point in
- /// `anchor` and declare that the `dependent` program point must be updated
- /// every time `point` is.
+ /// Get the dense lattice before the execution of the program point `point`
+ /// and declare that the `dependent` program point must be updated every time
+ /// `point` is.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor);
+ ProgramPoint point);
/// Set the dense lattice before at the control flow exit point and propagate
/// the update if it changed.
@@ -501,9 +500,9 @@ class DenseBackwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice at the given lattice anchor.
- LatticeT *getLattice(LatticeAnchor anchor) override {
- return getOrCreate<LatticeT>(anchor);
+ /// Get the dense lattice at the given program point.
+ LatticeT *getLattice(ProgramPoint point) override {
+ return getOrCreate<LatticeT>(point);
}
/// Set the dense lattice at control flow exit point (after the terminator)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -50,7 +50,7 @@ class IntegerRangeAnalysis
/// At an entry point, we cannot reason about interger value ranges.
void setToEntryState(IntegerValueRangeLattice *lattice) override {
propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
- lattice->getAnchor())));
+ lattice->getPoint())));
}
/// Visit an operation. Invoke the transfer function on each operation that
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 933790b4f2a6eb..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState {
/// Lattices can only be created for values.
AbstractSparseLattice(Value value) : AnalysisState(value) {}
- /// Return the value this lattice is located at.
- Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
@@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;
- /// Return the value this lattice is located at.
- Value getAnchor() const { return anchor.get<Value>(); }
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return point.get<Value>(); }
/// Return the value held by this lattice. This requires that the value is
/// initialized.
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index b0450ecdbd99b8..2580ec28b51902 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -49,93 +49,79 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
/// Forward declare the analysis state class.
class AnalysisState;
-/// Program point represents a specific location in the execution of a program.
-/// A sequence of program points can be combined into a control flow graph.
-struct ProgramPoint : public PointerUnion<Operation *, Block *> {
- using ParentTy = PointerUnion<Operation *, Block *>;
- /// Inherit constructors.
- using ParentTy::PointerUnion;
- /// Allow implicit conversion from the parent type.
- ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
- /// Allow implicit conversions from operation wrappers.
- /// TODO: For Windows only. Find a better solution.
- template <typename OpT, typename = std::enable_if_t<
- std::is_convertible<OpT, Operation *>::value &&
- !std::is_same<OpT, Operation *>::value>>
- ProgramPoint(OpT op) : ParentTy(op) {}
-
- /// Print the program point.
- void print(raw_ostream &os) const;
-};
-
//===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
//===----------------------------------------------------------------------===//
-/// Abstract class for generic lattice anchor. In classical data-flow analysis,
-/// lattice anchor represent positions in a program to which lattice elements
+/// Abstract class for generic program points. In classical data-flow analysis,
+/// program points represent positions in a program to which lattice elements
/// are attached. In sparse data-flow analysis, these can be SSA values, and in
/// dense data-flow analysis, these are the program points before and after
/// every operation.
///
-/// Lattice anchor are implemented using MLIR's storage uniquer framework and
+/// In the general MLIR data-flow analysis framework, program points are an
+/// extensible concept. Program points are uniquely identifiable objects to
+/// which analysis states can be attached. The semantics of program points are
+/// defined by the analyses that specify their transfer functions.
+///
+/// Program points are implemented using MLIR's storage uniquer framework and
/// type ID system to provide RTTI.
-class GenericLatticeAnchor : public StorageUniquer::BaseStorage {
+class GenericProgramPoint : public StorageUniquer::BaseStorage {
public:
- virtual ~GenericLatticeAnchor();
+ virtual ~GenericProgramPoint();
- /// Get the abstract lattice anchor's type identifier.
+ /// Get the abstract program point's type identifier.
TypeID getTypeID() const { return typeID; }
- /// Get a derived source location for the lattice anchor.
+ /// Get a derived source location for the program point.
virtual Location getLoc() const = 0;
- /// Print the lattice anchor.
+ /// Print the program point.
virtual void print(raw_ostream &os) const = 0;
protected:
- /// Create an abstract lattice anchor with type identifier.
- explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {}
+ /// Create an abstract program point with type identifier.
+ explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
private:
- /// The type identifier of the lattice anchor.
+ /// The type identifier of the program point.
TypeID typeID;
};
//===----------------------------------------------------------------------===//
-// GenericLatticeAnchorBase
+// GenericProgramPointBase
//===----------------------------------------------------------------------===//
-/// Base class for generic lattice anchor based on a concrete lattice anchor
+/// Base class for generic program points based on a concrete program point
/// type and a content key. This class defines the common methods required for
/// operability with the storage uniquer framework.
///
-/// The provided key type uniquely identifies the concrete lattice anchor
+/// The provided key type uniquely identifies the concrete program point
/// instance and are the data members of the class.
template <typename ConcreteT, typename Value>
-class GenericLatticeAnchorBase : public GenericLatticeAnchor {
+class GenericProgramPointBase : public GenericProgramPoint {
public:
/// The concrete key type used by the storage uniquer. This class is uniqued
/// by its contents.
using KeyTy = Value;
/// Alias for the base class.
- using Base = GenericLatticeAnchorBase<ConcreteT, Value>;
+ using Base = GenericProgramPointBase<ConcreteT, Value>;
- /// Construct an instance of the lattice anchor using the provided value and
+ /// Construct an instance of the program point using the provided value and
/// the type ID of the concrete type.
template <typename ValueT>
- explicit GenericLatticeAnchorBase(ValueT &&value)
- : GenericLatticeAnchor(TypeID::get<ConcreteT>()),
+ explicit GenericProgramPointBase(ValueT &&value)
+ : GenericProgramPoint(TypeID::get<ConcreteT>()),
value(std::forward<ValueT>(value)) {}
- /// Get a uniqued instance of this lattice anchor class with the given
+ /// Get a uniqued instance of this program point class with the given
/// arguments.
template <typename... Args>
static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
}
- /// Allocate space for a lattice anchor and construct it in-place.
+ /// Allocate space for a program point and construct it in-place.
template <typename ValueT>
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
ValueT &&value) {
@@ -143,48 +129,46 @@ class GenericLatticeAnchorBase : public GenericLatticeAnchor {
ConcreteT(std::forward<ValueT>(value));
}
- /// Two lattice anchors are equal if their values are equal.
+ /// Two program points are equal if their values are equal.
bool operator==(const Value &value) const { return this->value == value; }
/// Provide LLVM-style RTTI using type IDs.
- static bool classof(const GenericLatticeAnchor *point) {
+ static bool classof(const GenericProgramPoint *point) {
return point->getTypeID() == TypeID::get<ConcreteT>();
}
- /// Get the contents of the lattice anchor.
+ /// Get the contents of the program point.
const Value &getValue() const { return value; }
private:
- /// The lattice anchor value.
+ /// The program point value.
Value value;
};
//===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
//===----------------------------------------------------------------------===//
-/// Fundamental IR components are supported as first-class lattice anchor.
-struct LatticeAnchor
- : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
- using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
+/// Fundamental IR components are supported as first-class program points.
+struct ProgramPoint
+ : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
+ using ParentTy =
+ PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
/// Inherit constructors.
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
- LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
+ ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
/// Allow implicit conversions from operation wrappers.
/// TODO: For Windows only. Find a better solution.
template <typename OpT, typename = std::enable_if_t<
std::is_convertible<OpT, Operation *>::value &&
!std::is_same<OpT, Operation *>::value>>
- LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
-
- LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
- LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
+ ProgramPoint(OpT op) : ParentTy(op) {}
- /// Print the lattice anchor.
+ /// Print the program point.
void print(raw_ostream &os) const;
- /// Get the source location of the lattice anchor.
+ /// Get the source location of the program point.
Location getLoc() const;
};
@@ -223,8 +207,8 @@ class DataFlowConfig {
/// The general data-flow analysis solver. This class is responsible for
/// orchestrating child data-flow analyses, running the fixed-point iteration
-/// algorithm, managing analysis state and lattice anchor memory, and tracking
-/// dependencies between analyses, lattice anchor, and analysis states.
+/// algorithm, managing analysis state and program point memory, and tracking
+/// dependencies between analyses, program points, and analysis states.
///
/// Steps to run a data-flow analysis:
///
@@ -248,33 +232,32 @@ class DataFlowSolver {
/// operation and run the analysis until fixpoint.
LogicalResult initializeAndRun(Operation *top);
- /// Lookup an analysis state for the given lattice anchor. Returns null if one
+ /// Lookup an analysis state for the given program point. Returns null if one
/// does not exist.
- template <typename StateT, typename AnchorT>
- const StateT *lookupState(AnchorT anchor) const {
- auto it =
- analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
+ template <typename StateT, typename PointT>
+ const StateT *lookupState(PointT point) const {
+ auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
}
- /// Erase any analysis state associated with the given lattice anchor.
- template <typename AnchorT>
- void eraseState(AnchorT anchor) {
- LatticeAnchor la(anchor);
+ /// Erase any analysis state associated with the given program point.
+ template <typename PointT>
+ void eraseState(PointT point) {
+ ProgramPoint pp(point);
for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
- if (it->first.first == la)
+ if (it->first.first == pp)
analysisStates.erase(it);
}
}
- /// Get a uniqued lattice anchor instance. If one is not present, it is
+ /// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
- template <typename AnchorT, typename... Args>
- AnchorT *getLatticeAnchor(Args &&...args) {
- return AnchorT::get(uniquer, std::forward<Args>(args)...);
+ template <typename PointT, typename... Args>
+ PointT *getProgramPoint(Args &&...args) {
+ return PointT::get(uniquer, std::forward<Args>(args)...);
}
/// A work item on the solver queue is a program point, child analysis pair.
@@ -284,10 +267,10 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
- /// Get the state associated with the given lattice anchor. If it does not
+ /// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
- template <typename StateT, typename AnchorT>
- StateT *getOrCreateState(AnchorT anchor);
+ template <typename StateT, typename PointT>
+ StateT *getOrCreateState(PointT point);
/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
@@ -308,13 +291,13 @@ class DataFlowSolver {
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
- /// The storage uniquer instance that owns the memory of the allocated lattice
- /// anchors
+ /// The storage uniquer instance that owns the memory of the allocated program
+ /// points.
StorageUniquer uniquer;
- /// A ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108032
More information about the Mlir-commits mailing list