[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)
Kunwar Grover
llvmlistbot at llvm.org
Tue Sep 10 07:16:57 PDT 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/108032
This patch adds speculation behavior for linalg structured ops, allowing them to be hoisted out of loops by loop-invariant code motion (LICM).
From ac0ef19c9953bdb947461bd7f90a85ce92b6b32f Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 28 Aug 2024 16:10:11 -0700
Subject: [PATCH 1/4] Revert "[mlir] [dataflow] Refactoring the definition of
program points in data flow analysis (#105656)"
This reverts commit b6603e1bf11dee4761e49af6581c8b8f074b705d.
---
.../mlir/Analysis/DataFlow/DeadCodeAnalysis.h | 16 +-
.../mlir/Analysis/DataFlow/DenseAnalysis.h | 37 ++-
.../Analysis/DataFlow/IntegerRangeAnalysis.h | 2 +-
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 8 +-
.../include/mlir/Analysis/DataFlowFramework.h | 236 ++++++++----------
.../Analysis/DataFlow/DeadCodeAnalysis.cpp | 32 ++-
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 22 +-
.../DataFlow/IntegerRangeAnalysis.cpp | 2 +-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 25 +-
mlir/lib/Analysis/DataFlowFramework.cpp | 45 ++--
.../DataFlow/TestDeadCodeAnalysis.cpp | 2 +-
.../DataFlow/TestDenseDataFlowAnalysis.h | 2 +-
.../lib/Analysis/TestDataFlowFramework.cpp | 12 +-
13 files changed, 204 insertions(+), 237 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 80c8b86c63678a..10ef8b6ba5843a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -35,21 +35,21 @@ namespace dataflow {
//===----------------------------------------------------------------------===//
/// This is a simple analysis state that represents whether the associated
-/// lattice anchor (either a block or a control-flow edge) is live.
+/// program point (either a block or a control-flow edge) is live.
class Executable : public AnalysisState {
public:
using AnalysisState::AnalysisState;
- /// Set the state of the lattice anchor to live.
+ /// Set the state of the program point to live.
ChangeResult setToLive();
- /// Get whether the lattice anchor is live.
+ /// Get whether the program point is live.
bool isLive() const { return live; }
/// Print the liveness.
void print(raw_ostream &os) const override;
- /// When the state of the lattice anchor is changed to live, re-invoke
+ /// When the state of the program point is changed to live, re-invoke
/// subscribed analyses on the operations in the block and on the block
/// itself.
void onUpdate(DataFlowSolver *solver) const override;
@@ -60,8 +60,8 @@ class Executable : public AnalysisState {
}
private:
- /// Whether the lattice anchor is live. Optimistically assume that the lattice
- /// anchor is dead.
+ /// Whether the program point is live. Optimistically assume that the program
+ /// point is dead.
bool live = false;
/// A set of analyses that should be updated when this state changes.
@@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState {
// CFGEdge
//===----------------------------------------------------------------------===//
-/// This lattice anchor represents a control-flow edge between a block and one
+/// This program point represents a control-flow edge between a block and one
/// of its successors.
class CFGEdge
- : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> {
+ : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
public:
using Base::Base;
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 7917f1e3ba6485..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -91,16 +91,15 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;
- /// Get the dense lattice after the execution of the given lattice anchor.
- virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+ /// Get the dense lattice after the execution of the given program point.
+ virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
/// Get the dense lattice after the execution of the given program point and
- /// add it as a dependency to a lattice anchor. That is, every time the
- /// lattice after anchor is updated, the dependent program point must be
- /// visited, and the newly triggered visit might update the lattice after
- /// dependent.
+ /// add it as a dependency to a program point. That is, every time the lattice
+ /// after point is updated, the dependent program point must be visited, and
+ /// the newly triggered visit might update the lattice after dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor);
+ ProgramPoint point);
/// Set the dense lattice at control flow entry point and propagate an update
/// if it changed.
@@ -250,9 +249,9 @@ class DenseForwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice on this lattice anchor.
- LatticeT *getLattice(LatticeAnchor anchor) override {
- return getOrCreate<LatticeT>(anchor);
+ /// Get the dense lattice after this program point.
+ LatticeT *getLattice(ProgramPoint point) override {
+ return getOrCreate<LatticeT>(point);
}
/// Set the dense lattice at control flow entry point and propagate an update
@@ -332,16 +331,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &after,
AbstractDenseLattice *before) = 0;
- /// Get the dense lattice before the execution of the lattice anchor. That is,
+ /// Get the dense lattice before the execution of the program point. That is,
/// before the execution of the given operation or after the execution of the
/// block.
- virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+ virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
- /// Get the dense lattice before the execution of the program point in
- /// `anchor` and declare that the `dependent` program point must be updated
- /// every time `point` is.
+ /// Get the dense lattice before the execution of the program point `point`
+ /// and declare that the `dependent` program point must be updated every time
+ /// `point` is.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor);
+ ProgramPoint point);
/// Set the dense lattice before at the control flow exit point and propagate
/// the update if it changed.
@@ -501,9 +500,9 @@ class DenseBackwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice at the given lattice anchor.
- LatticeT *getLattice(LatticeAnchor anchor) override {
- return getOrCreate<LatticeT>(anchor);
+ /// Get the dense lattice at the given program point.
+ LatticeT *getLattice(ProgramPoint point) override {
+ return getOrCreate<LatticeT>(point);
}
/// Set the dense lattice at control flow exit point (after the terminator)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -50,7 +50,7 @@ class IntegerRangeAnalysis
/// At an entry point, we cannot reason about integer value ranges.
void setToEntryState(IntegerValueRangeLattice *lattice) override {
propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
- lattice->getAnchor())));
+ lattice->getPoint())));
}
/// Visit an operation. Invoke the transfer function on each operation that
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 933790b4f2a6eb..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState {
/// Lattices can only be created for values.
AbstractSparseLattice(Value value) : AnalysisState(value) {}
- /// Return the value this lattice is located at.
- Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
@@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;
- /// Return the value this lattice is located at.
- Value getAnchor() const { return anchor.get<Value>(); }
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return point.get<Value>(); }
/// Return the value held by this lattice. This requires that the value is
/// initialized.
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index b0450ecdbd99b8..2580ec28b51902 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -49,93 +49,79 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
/// Forward declare the analysis state class.
class AnalysisState;
-/// Program point represents a specific location in the execution of a program.
-/// A sequence of program points can be combined into a control flow graph.
-struct ProgramPoint : public PointerUnion<Operation *, Block *> {
- using ParentTy = PointerUnion<Operation *, Block *>;
- /// Inherit constructors.
- using ParentTy::PointerUnion;
- /// Allow implicit conversion from the parent type.
- ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
- /// Allow implicit conversions from operation wrappers.
- /// TODO: For Windows only. Find a better solution.
- template <typename OpT, typename = std::enable_if_t<
- std::is_convertible<OpT, Operation *>::value &&
- !std::is_same<OpT, Operation *>::value>>
- ProgramPoint(OpT op) : ParentTy(op) {}
-
- /// Print the program point.
- void print(raw_ostream &os) const;
-};
-
//===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
//===----------------------------------------------------------------------===//
-/// Abstract class for generic lattice anchor. In classical data-flow analysis,
-/// lattice anchor represent positions in a program to which lattice elements
+/// Abstract class for generic program points. In classical data-flow analysis,
+/// program points represent positions in a program to which lattice elements
/// are attached. In sparse data-flow analysis, these can be SSA values, and in
/// dense data-flow analysis, these are the program points before and after
/// every operation.
///
-/// Lattice anchor are implemented using MLIR's storage uniquer framework and
+/// In the general MLIR data-flow analysis framework, program points are an
+/// extensible concept. Program points are uniquely identifiable objects to
+/// which analysis states can be attached. The semantics of program points are
+/// defined by the analyses that specify their transfer functions.
+///
+/// Program points are implemented using MLIR's storage uniquer framework and
/// type ID system to provide RTTI.
-class GenericLatticeAnchor : public StorageUniquer::BaseStorage {
+class GenericProgramPoint : public StorageUniquer::BaseStorage {
public:
- virtual ~GenericLatticeAnchor();
+ virtual ~GenericProgramPoint();
- /// Get the abstract lattice anchor's type identifier.
+ /// Get the abstract program point's type identifier.
TypeID getTypeID() const { return typeID; }
- /// Get a derived source location for the lattice anchor.
+ /// Get a derived source location for the program point.
virtual Location getLoc() const = 0;
- /// Print the lattice anchor.
+ /// Print the program point.
virtual void print(raw_ostream &os) const = 0;
protected:
- /// Create an abstract lattice anchor with type identifier.
- explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {}
+ /// Create an abstract program point with type identifier.
+ explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
private:
- /// The type identifier of the lattice anchor.
+ /// The type identifier of the program point.
TypeID typeID;
};
//===----------------------------------------------------------------------===//
-// GenericLatticeAnchorBase
+// GenericProgramPointBase
//===----------------------------------------------------------------------===//
-/// Base class for generic lattice anchor based on a concrete lattice anchor
+/// Base class for generic program points based on a concrete program point
/// type and a content key. This class defines the common methods required for
/// operability with the storage uniquer framework.
///
-/// The provided key type uniquely identifies the concrete lattice anchor
+/// The provided key type uniquely identifies the concrete program point
/// instance and are the data members of the class.
template <typename ConcreteT, typename Value>
-class GenericLatticeAnchorBase : public GenericLatticeAnchor {
+class GenericProgramPointBase : public GenericProgramPoint {
public:
/// The concrete key type used by the storage uniquer. This class is uniqued
/// by its contents.
using KeyTy = Value;
/// Alias for the base class.
- using Base = GenericLatticeAnchorBase<ConcreteT, Value>;
+ using Base = GenericProgramPointBase<ConcreteT, Value>;
- /// Construct an instance of the lattice anchor using the provided value and
+ /// Construct an instance of the program point using the provided value and
/// the type ID of the concrete type.
template <typename ValueT>
- explicit GenericLatticeAnchorBase(ValueT &&value)
- : GenericLatticeAnchor(TypeID::get<ConcreteT>()),
+ explicit GenericProgramPointBase(ValueT &&value)
+ : GenericProgramPoint(TypeID::get<ConcreteT>()),
value(std::forward<ValueT>(value)) {}
- /// Get a uniqued instance of this lattice anchor class with the given
+ /// Get a uniqued instance of this program point class with the given
/// arguments.
template <typename... Args>
static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
}
- /// Allocate space for a lattice anchor and construct it in-place.
+ /// Allocate space for a program point and construct it in-place.
template <typename ValueT>
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
ValueT &&value) {
@@ -143,48 +129,46 @@ class GenericLatticeAnchorBase : public GenericLatticeAnchor {
ConcreteT(std::forward<ValueT>(value));
}
- /// Two lattice anchors are equal if their values are equal.
+ /// Two program points are equal if their values are equal.
bool operator==(const Value &value) const { return this->value == value; }
/// Provide LLVM-style RTTI using type IDs.
- static bool classof(const GenericLatticeAnchor *point) {
+ static bool classof(const GenericProgramPoint *point) {
return point->getTypeID() == TypeID::get<ConcreteT>();
}
- /// Get the contents of the lattice anchor.
+ /// Get the contents of the program point.
const Value &getValue() const { return value; }
private:
- /// The lattice anchor value.
+ /// The program point value.
Value value;
};
//===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
//===----------------------------------------------------------------------===//
-/// Fundamental IR components are supported as first-class lattice anchor.
-struct LatticeAnchor
- : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
- using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
+/// Fundamental IR components are supported as first-class program points.
+struct ProgramPoint
+ : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
+ using ParentTy =
+ PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
/// Inherit constructors.
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
- LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
+ ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
/// Allow implicit conversions from operation wrappers.
/// TODO: For Windows only. Find a better solution.
template <typename OpT, typename = std::enable_if_t<
std::is_convertible<OpT, Operation *>::value &&
!std::is_same<OpT, Operation *>::value>>
- LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
-
- LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
- LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
+ ProgramPoint(OpT op) : ParentTy(op) {}
- /// Print the lattice anchor.
+ /// Print the program point.
void print(raw_ostream &os) const;
- /// Get the source location of the lattice anchor.
+ /// Get the source location of the program point.
Location getLoc() const;
};
@@ -223,8 +207,8 @@ class DataFlowConfig {
/// The general data-flow analysis solver. This class is responsible for
/// orchestrating child data-flow analyses, running the fixed-point iteration
-/// algorithm, managing analysis state and lattice anchor memory, and tracking
-/// dependencies between analyses, lattice anchor, and analysis states.
+/// algorithm, managing analysis state and program point memory, and tracking
+/// dependencies between analyses, program points, and analysis states.
///
/// Steps to run a data-flow analysis:
///
@@ -248,33 +232,32 @@ class DataFlowSolver {
/// operation and run the analysis until fixpoint.
LogicalResult initializeAndRun(Operation *top);
- /// Lookup an analysis state for the given lattice anchor. Returns null if one
+ /// Lookup an analysis state for the given program point. Returns null if one
/// does not exist.
- template <typename StateT, typename AnchorT>
- const StateT *lookupState(AnchorT anchor) const {
- auto it =
- analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
+ template <typename StateT, typename PointT>
+ const StateT *lookupState(PointT point) const {
+ auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
}
- /// Erase any analysis state associated with the given lattice anchor.
- template <typename AnchorT>
- void eraseState(AnchorT anchor) {
- LatticeAnchor la(anchor);
+ /// Erase any analysis state associated with the given program point.
+ template <typename PointT>
+ void eraseState(PointT point) {
+ ProgramPoint pp(point);
for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
- if (it->first.first == la)
+ if (it->first.first == pp)
analysisStates.erase(it);
}
}
- /// Get a uniqued lattice anchor instance. If one is not present, it is
+ /// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
- template <typename AnchorT, typename... Args>
- AnchorT *getLatticeAnchor(Args &&...args) {
- return AnchorT::get(uniquer, std::forward<Args>(args)...);
+ template <typename PointT, typename... Args>
+ PointT *getProgramPoint(Args &&...args) {
+ return PointT::get(uniquer, std::forward<Args>(args)...);
}
/// A work item on the solver queue is a program point, child analysis pair.
@@ -284,10 +267,10 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
- /// Get the state associated with the given lattice anchor. If it does not
+ /// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
- template <typename StateT, typename AnchorT>
- StateT *getOrCreateState(AnchorT anchor);
+ template <typename StateT, typename PointT>
+ StateT *getOrCreateState(PointT point);
/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
@@ -308,13 +291,13 @@ class DataFlowSolver {
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
- /// The storage uniquer instance that owns the memory of the allocated lattice
- /// anchors
+ /// The storage uniquer instance that owns the memory of the allocated program
+ /// points.
StorageUniquer uniquer;
- /// A type-erased map of lattice anchors to associated analysis states for
- /// first-class lattice anchors.
- DenseMap<std::pair<LatticeAnchor, TypeID>, std::unique_ptr<AnalysisState>>
+ /// A type-erased map of program points to associated analysis states for
+ /// first-class program points.
+ DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
analysisStates;
/// Allow the base child analysis class to access the internals of the solver.
@@ -326,13 +309,13 @@ class DataFlowSolver {
//===----------------------------------------------------------------------===//
/// Base class for generic analysis states. Analysis states contain data-flow
-/// information that are attached to lattice anchors and which evolve as the
+/// information that are attached to program points and which evolve as the
/// analysis iterates.
///
/// This class places no restrictions on the semantics of analysis states beyond
/// these requirements.
///
-/// 1. Querying the state of a lattice anchor prior to visiting that anchor
+/// 1. Querying the state of a program point prior to visiting that point
/// results in uninitialized state. Analyses must be aware of uninitialized
/// states.
/// 2. Analysis states can reach fixpoints, where subsequent updates will never
@@ -343,20 +326,20 @@ class AnalysisState {
public:
virtual ~AnalysisState();
- /// Create the analysis state at the given lattice anchor.
- AnalysisState(LatticeAnchor anchor) : anchor(anchor) {}
+ /// Create the analysis state at the given program point.
+ AnalysisState(ProgramPoint point) : point(point) {}
- /// Returns the lattice anchor this state is located at.
- LatticeAnchor getAnchor() const { return anchor; }
+ /// Returns the program point this state is located at.
+ ProgramPoint getPoint() const { return point; }
/// Print the contents of the analysis state.
virtual void print(raw_ostream &os) const = 0;
LLVM_DUMP_METHOD void dump() const;
- /// Add a dependency to this analysis state on a lattice anchor and an
+ /// Add a dependency to this analysis state on a program point and an
/// analysis. If this state is updated, the analysis will be invoked on the
- /// given lattice anchor again (in onUpdate()).
- void addDependency(ProgramPoint point, DataFlowAnalysis *analysis);
+ /// given program point again (in onUpdate()).
+ void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
protected:
/// This function is called by the solver when the analysis state is updated
@@ -368,8 +351,8 @@ class AnalysisState {
solver->enqueue(item);
}
- /// The lattice anchor to which the state belongs.
- LatticeAnchor anchor;
+ /// The program point to which the state belongs.
+ ProgramPoint point;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analysis state.
@@ -378,8 +361,8 @@ class AnalysisState {
private:
/// The dependency relations originating from this analysis state. An entry
- /// `state -> (analysis, anchor)` is created when `analysis` queries `state`
- /// when updating `anchor`.
+ /// `state -> (analysis, point)` is created when `analysis` queries `state`
+ /// when updating `point`.
///
/// When this state is updated, all dependent child analysis invocations are
/// pushed to the back of the queue. Use a `SetVector` to keep the analysis
@@ -420,7 +403,7 @@ class DataFlowAnalysis {
explicit DataFlowAnalysis(DataFlowSolver &solver);
/// Initialize the analysis from the provided top-level operation by building
- /// an initial dependency graph between all lattice anchors of interest. This
+ /// an initial dependency graph between all program points of interest. This
/// can be implemented by calling `visit` on all program points of interest
/// below the top-level operation.
///
@@ -449,39 +432,39 @@ class DataFlowAnalysis {
virtual LogicalResult visit(ProgramPoint point) = 0;
protected:
- /// Create a dependency between the given analysis state and lattice anchor
+ /// Create a dependency between the given analysis state and program point
/// on this analysis.
void addDependency(AnalysisState *state, ProgramPoint point);
/// Propagate an update to a state if it changed.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
- /// Register a custom lattice anchor class.
- template <typename AnchorT>
- void registerAnchorKind() {
- solver.uniquer.registerParametricStorageType<AnchorT>();
+ /// Register a custom program point class.
+ template <typename PointT>
+ void registerPointKind() {
+ solver.uniquer.registerParametricStorageType<PointT>();
}
- /// Get or create a custom lattice anchor.
- template <typename AnchorT, typename... Args>
- AnchorT *getLatticeAnchor(Args &&...args) {
- return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
+ /// Get or create a custom program point.
+ template <typename PointT, typename... Args>
+ PointT *getProgramPoint(Args &&...args) {
+ return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
}
- /// Get the analysis state associated with the lattice anchor. The returned
+ /// Get the analysis state associated with the program point. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
- template <typename StateT, typename AnchorT>
- StateT *getOrCreate(AnchorT anchor) {
- return solver.getOrCreateState<StateT>(anchor);
+ template <typename StateT, typename PointT>
+ StateT *getOrCreate(PointT point) {
+ return solver.getOrCreateState<StateT>(point);
}
/// Get a read-only analysis state for the given point and create a dependency
/// on `dependent`. If the return state is updated elsewhere, this analysis is
/// re-invoked on the dependent.
- template <typename StateT, typename AnchorT>
- const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) {
- StateT *state = getOrCreate<StateT>(anchor);
+ template <typename StateT, typename PointT>
+ const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
+ StateT *state = getOrCreate<StateT>(point);
addDependency(state, dependent);
return state;
}
@@ -511,12 +494,12 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
return static_cast<AnalysisT *>(childAnalyses.back().get());
}
-template <typename StateT, typename AnchorT>
-StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
+template <typename StateT, typename PointT>
+StateT *DataFlowSolver::getOrCreateState(PointT point) {
std::unique_ptr<AnalysisState> &state =
- analysisStates[{LatticeAnchor(anchor), TypeID::get<StateT>()}];
+ analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
if (!state) {
- state = std::unique_ptr<StateT>(new StateT(anchor));
+ state = std::unique_ptr<StateT>(new StateT(point));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state->debugName = llvm::getTypeName<StateT>();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -529,32 +512,20 @@ inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
return os;
}
-inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) {
- anchor.print(os);
+inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
+ point.print(os);
return os;
}
} // end namespace mlir
namespace llvm {
-/// Allow hashing of lattice anchors and program points.
-template <>
-struct DenseMapInfo<mlir::LatticeAnchor>
- : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
-
+/// Allow hashing of program points.
template <>
struct DenseMapInfo<mlir::ProgramPoint>
: public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
// Allow llvm::cast style functions.
-template <typename To>
-struct CastInfo<To, mlir::LatticeAnchor>
- : public CastInfo<To, mlir::LatticeAnchor::PointerUnion> {};
-
-template <typename To>
-struct CastInfo<To, const mlir::LatticeAnchor>
- : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};
-
template <typename To>
struct CastInfo<To, mlir::ProgramPoint>
: public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
@@ -563,11 +534,6 @@ template <typename To>
struct CastInfo<To, const mlir::ProgramPoint>
: public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
-/// Allow stealing the low bits of a ProgramPoint.
-template <>
-struct PointerLikeTypeTraits<mlir::ProgramPoint>
- : public PointerLikeTypeTraits<mlir::ProgramPoint::ParentTy> {};
-
} // end namespace llvm
#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 532480b6fad57d..fab2bd83888da8 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -46,20 +46,17 @@ void Executable::print(raw_ostream &os) const {
void Executable::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
- if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) {
- if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) {
- // Re-invoke the analyses on the block itself.
- for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({block, analysis});
- // Re-invoke the analyses on all operations in the block.
- for (DataFlowAnalysis *analysis : subscribers)
- for (Operation &op : *block)
- solver->enqueue({&op, analysis});
- }
- } else if (auto *latticeAnchor =
- llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
+ // Re-invoke the analyses on the block itself.
+ for (DataFlowAnalysis *analysis : subscribers)
+ solver->enqueue({block, analysis});
+ // Re-invoke the analyses on all operations in the block.
+ for (DataFlowAnalysis *analysis : subscribers)
+ for (Operation &op : *block)
+ solver->enqueue({&op, analysis});
+ } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
// Re-invoke the analysis on the successor block.
- if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
+ if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({edge->getTo(), analysis});
}
@@ -117,7 +114,7 @@ void CFGEdge::print(raw_ostream &os) const {
DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
- registerAnchorKind<CFGEdge>();
+ registerPointKind<CFGEdge>();
}
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
@@ -221,8 +218,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
auto *state = getOrCreate<Executable>(to);
propagateIfChanged(state, state->setToLive());
- auto *edgeState =
- getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
+ auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
propagateIfChanged(edgeState, edgeState->setToLive());
}
@@ -238,7 +234,9 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
- auto *op = point.get<Operation *>();
+ auto *op = llvm::dyn_cast_if_present<Operation *>(point);
+ if (!op)
+ return emitError(point.getLoc(), "unknown program point kind");
// If the parent block is not executable, there is nothing to do.
if (!getOrCreate<Executable>(op->getBlock())->isLive())
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 37f4ceaaa56cee..33c877f78f4bf6 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -47,7 +47,10 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
return processOperation(op);
- visitBlock(point.get<Block *>());
+ else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
+ visitBlock(block);
+ else
+ return failure();
return success();
}
@@ -177,7 +180,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Skip control edges that aren't executable.
Block *predecessor = *it;
if (!getOrCreateFor<Executable>(
- block, getLatticeAnchor<CFGEdge>(predecessor, block))
+ block, getProgramPoint<CFGEdge>(predecessor, block))
->isLive())
continue;
@@ -245,8 +248,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
const AbstractDenseLattice *
AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor) {
- AbstractDenseLattice *state = getLattice(anchor);
+ ProgramPoint point) {
+ AbstractDenseLattice *state = getLattice(point);
addDependency(state, dependent);
return state;
}
@@ -276,7 +279,10 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
return processOperation(op);
- visitBlock(point.get<Block *>());
+ else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
+ visitBlock(block);
+ else
+ return failure();
return success();
}
@@ -418,7 +424,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// Meet the state with the state before block's successors.
for (Block *successor : block->getSuccessors()) {
if (!getOrCreateFor<Executable>(block,
- getLatticeAnchor<CFGEdge>(block, successor))
+ getProgramPoint<CFGEdge>(block, successor))
->isLive())
continue;
@@ -468,8 +474,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
const AbstractDenseLattice *
AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
- LatticeAnchor anchor) {
- AbstractDenseLattice *state = getLattice(anchor);
+ ProgramPoint point) {
+ AbstractDenseLattice *state = getLattice(point);
addDependency(state, dependent);
return state;
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9a95f172d5df48..35d38ea02d7162 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -42,7 +42,7 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
// If the integer range can be narrowed to a constant, update the constant
// value of the SSA value.
std::optional<APInt> constant = getValue().getValue().getConstantValue();
- auto value = anchor.get<Value>();
+ auto value = point.get<Value>();
auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
if (!constant)
return solver->propagateIfChanged(
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4a73f21a18aae7..d47d5fec8a9a6a 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -34,7 +34,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
// Push all users of the value to the queue.
- for (Operation *user : anchor.get<Value>().getUsers())
+ for (Operation *user : point.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
solver->enqueue({user, analysis});
}
@@ -46,7 +46,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
- registerAnchorKind<CFGEdge>();
+ registerPointKind<CFGEdge>();
}
LogicalResult
@@ -86,7 +86,10 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
return visitOperation(op);
- visitBlock(point.get<Block *>());
+ else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
+ visitBlock(block);
+ else
+ return failure();
return success();
}
@@ -214,7 +217,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// If the edge from the predecessor block to the current block is not live,
// bail out.
auto *edgeExecutable =
- getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
+ getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
edgeExecutable->blockContentSubscribe(this);
if (!edgeExecutable->isLive())
continue;
@@ -321,7 +324,7 @@ void AbstractSparseForwardDataFlowAnalysis::join(
AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: DataFlowAnalysis(solver), symbolTable(symbolTable) {
- registerAnchorKind<CFGEdge>();
+ registerPointKind<CFGEdge>();
}
LogicalResult
@@ -352,10 +355,14 @@ LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
return visitOperation(op);
- // For backward dataflow, we don't have to do any work for the blocks
- // themselves. CFG edges between blocks are processed by the BranchOp
- // logic in `visitOperation`, and entry blocks for functions are tied
- // to the CallOp arguments by visitOperation.
+ else if (llvm::dyn_cast_if_present<Block *>(point))
+ // For backward dataflow, we don't have to do any work for the blocks
+ // themselves. CFG edges between blocks are processed by the BranchOp
+ // logic in `visitOperation`, and entry blocks for functions are tied
+ // to the CallOp arguments by visitOperation.
+ return success();
+ else
+ return failure();
return success();
}
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index a65ddc13143bae..d0e827aa1c2b64 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -26,10 +26,10 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
//===----------------------------------------------------------------------===//
-GenericLatticeAnchor::~GenericLatticeAnchor() = default;
+GenericProgramPoint::~GenericProgramPoint() = default;
//===----------------------------------------------------------------------===//
// AnalysisState
@@ -44,7 +44,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
DATAFLOW_DEBUG({
if (inserted) {
llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << anchor << "\nand " << debugName << " on " << dependent
+ << point << "\nand " << debugName << " on " << dependent
<< "\n";
}
});
@@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
void AnalysisState::dump() const { print(llvm::errs()); }
//===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
//===----------------------------------------------------------------------===//
void ProgramPoint::print(raw_ostream &os) const {
@@ -61,36 +61,23 @@ void ProgramPoint::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
- if (Operation *op = llvm::dyn_cast<Operation *>(*this)) {
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
+ return programPoint->print(os);
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->print(os, OpPrintingFlags().skipRegions());
- }
- return get<Block *>()->print(os);
-}
-
-void LatticeAnchor::print(raw_ostream &os) const {
- if (isNull()) {
- os << "<NULL POINT>";
- return;
- }
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->print(os);
- if (auto value = llvm::dyn_cast<Value>(*this)) {
+ if (auto value = llvm::dyn_cast<Value>(*this))
return value.print(os, OpPrintingFlags().skipRegions());
- }
-
- return get<ProgramPoint>().print(os);
+ return get<Block *>()->print(os);
}
-Location LatticeAnchor::getLoc() const {
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->getLoc();
+Location ProgramPoint::getLoc() const {
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
+ return programPoint->getLoc();
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
+ return op->getLoc();
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
-
- ProgramPoint pp = get<ProgramPoint>();
- if (auto *op = llvm::dyn_cast<Operation *>(pp))
- return op->getLoc();
- return pp.get<Block *>()->getParent()->getLoc();
+ return get<Block *>()->getParent()->getLoc();
}
//===----------------------------------------------------------------------===//
@@ -130,7 +117,7 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
ChangeResult changed) {
if (changed == ChangeResult::Change) {
DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->anchor << "\n"
+ << " of " << state->point << "\n"
<< "Value: " << *state << "\n");
state->onUpdate(this);
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index d02efaaa3fe320..90973af9c2cf5d 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -40,7 +40,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
pred->printAsOperand(os);
os << " = ";
auto *live = solver.lookupState<Executable>(
- solver.getLatticeAnchor<CFGEdge>(pred, &block));
+ solver.getProgramPoint<CFGEdge>(pred, &block));
if (live)
os << *live;
else
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
index 86eb8651cb90c1..57fe0ca458de21 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
@@ -206,7 +206,7 @@ class UnderlyingValueAnalysis
/// At an entry point, the underlying value of a value is itself.
void setToEntryState(UnderlyingValueLattice *lattice) override {
propagateIfChanged(lattice,
- lattice->join(UnderlyingValue{lattice->getAnchor()}));
+ lattice->join(UnderlyingValue{lattice->getPoint()}));
}
/// Look for the most underlying value of a value.
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 9573ec1d143257..b6b33182440cf4 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -115,11 +115,15 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
}
LogicalResult FooAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
visitOperation(op);
- else
- visitBlock(point.get<Block *>());
- return success();
+ return success();
+ }
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
+ visitBlock(block);
+ return success();
+ }
+ return emitError(point.getLoc(), "unknown point kind");
}
void FooAnalysis::visitBlock(Block *block) {
>From 499c6b044e90c9a6676eace35e32fd233bc3c4e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 13:51:07 +0100
Subject: [PATCH 2/4] Add speculation for linalg structured ops
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 1 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 34 +++++++++++++++++++
.../mlir-linalg-ods-yaml-gen.cpp | 5 ++-
3 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..31f29139247267 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -29,6 +29,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat([
SingleBlockImplicitTerminator<"YieldOp">,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DestinationStyleOpInterface,
LinalgStructuredInterface,
ReifyRankedShapedTypeOpInterface], props)> {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..790e2ac50d9e39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -34,6 +34,7 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
@@ -1202,6 +1203,23 @@ void GenericOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+static Speculation::Speculatability
+getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
+ // Operands with value semantics are speculatable, while operands with memory
+ // semantics are not.
+ for (Value operand : linalgOp->getOperands()) {
+ if (isa<MemRefType>(operand.getType())) {
+ return Speculation::NotSpeculatable;
+ }
+ }
+ // The body of the op can still have speculation in its region.
+ return Speculation::RecursivelySpeculatable;
+}
+
+Speculation::Speculatability GenericOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
LogicalResult GenericOp::verify() { return success(); }
namespace {
@@ -1553,6 +1571,10 @@ void MapOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+Speculation::Speculatability MapOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
@@ -1621,6 +1643,10 @@ void ReduceOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+Speculation::Speculatability ReduceOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
NamedAttrList &attributes,
StringRef attributeName) {
@@ -1906,6 +1932,10 @@ void TransposeOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+Speculation::Speculatability TransposeOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
// Only the tensor type is supported.
@@ -2134,6 +2164,10 @@ void BroadcastOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+Speculation::Speculatability BroadcastOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 7311cdd39d0755..e5ce845b94b39b 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -656,7 +656,7 @@ ArrayAttr {0}::getIndexingMaps() {{
}
)FMT";
-// Implementations of fold and getEffects.
+// Implementations of fold, getEffects and getSpeculatability.
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
@@ -669,6 +669,9 @@ void {0}::getEffects(SmallVectorImpl<
if (hasPureTensorSemantics()) return;
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+Speculation::Speculatability {0}::getSpeculatability() {{
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
)FMT";
// Implementation of parse/print.
>From f32be329cc8e68ef6edd78689462be3deec2e892 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 15:14:05 +0100
Subject: [PATCH 3/4] Use hasPureTensorSemantics
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 790e2ac50d9e39..e2af2e1caad811 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1207,11 +1207,8 @@ static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
// Operands with value semantics are speculatable, while operands with memory
// semantics are not.
- for (Value operand : linalgOp->getOperands()) {
- if (isa<MemRefType>(operand.getType())) {
- return Speculation::NotSpeculatable;
- }
- }
+ if (!linalgOp.hasPureTensorSemantics())
+ return Speculation::NotSpeculatable;
// The body of the op can still have speculation in its region.
return Speculation::RecursivelySpeculatable;
}
>From 593b2b880b6fac3ab33a66f8a0e31da7d20184ce Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 15:15:11 +0100
Subject: [PATCH 4/4] Add test
---
.../loop-invariant-code-motion.mlir | 45 +++++++++++++++++++
1 file changed, 45 insertions(+)
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 47a49465e8a7cd..dbcac818a728b5 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -1118,3 +1118,48 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
}
return %0 : i32
}
+
+// -----
+
+#trait = {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,
+ %b : tensor<128x128xf32>,
+ %c: tensor<128x128xf32>,
+ %lb : index,
+ %ub : index,
+ %step : index,
+ %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+ %final =
+ scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+ -> tensor<?x128xf32> {
+ %compute = linalg.generic #trait
+ ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%c : tensor<128x128xf32>) {
+ ^bb0(%in : f32, %in2 : f32, %in3 : f32):
+ %mul = arith.mulf %in, %in2 : f32
+ %add = arith.addf %mul, %in3 : f32
+ linalg.yield %in3 : f32
+ } -> tensor<128x128xf32>
+
+ %newacc = tensor.insert_slice %compute into
+ %output[%i, 0][128, 128][1, 1]
+ : tensor<128x128xf32> into tensor<?x128xf32>
+ scf.yield %newacc : tensor<?x128xf32>
+ }
+
+ func.return %final : tensor<?x128xf32>
+}
More information about the Mlir-commits
mailing list