[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)

Kunwar Grover llvmlistbot at llvm.org
Tue Sep 10 07:16:57 PDT 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/108032

This patch adds speculation behavior for Linalg structured ops, allowing them to be hoisted out of loops by loop-invariant code motion (LICM).

>From ac0ef19c9953bdb947461bd7f90a85ce92b6b32f Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 28 Aug 2024 16:10:11 -0700
Subject: [PATCH 1/4] Revert "[mlir] [dataflow] Refactoring the definition of
 program points in data flow analysis (#105656)"

This reverts commit b6603e1bf11dee4761e49af6581c8b8f074b705d.
---
 .../mlir/Analysis/DataFlow/DeadCodeAnalysis.h |  16 +-
 .../mlir/Analysis/DataFlow/DenseAnalysis.h    |  37 ++-
 .../Analysis/DataFlow/IntegerRangeAnalysis.h  |   2 +-
 .../mlir/Analysis/DataFlow/SparseAnalysis.h   |   8 +-
 .../include/mlir/Analysis/DataFlowFramework.h | 236 ++++++++----------
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    |  32 ++-
 mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp  |  22 +-
 .../DataFlow/IntegerRangeAnalysis.cpp         |   2 +-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  25 +-
 mlir/lib/Analysis/DataFlowFramework.cpp       |  45 ++--
 .../DataFlow/TestDeadCodeAnalysis.cpp         |   2 +-
 .../DataFlow/TestDenseDataFlowAnalysis.h      |   2 +-
 .../lib/Analysis/TestDataFlowFramework.cpp    |  12 +-
 13 files changed, 204 insertions(+), 237 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 80c8b86c63678a..10ef8b6ba5843a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -35,21 +35,21 @@ namespace dataflow {
 //===----------------------------------------------------------------------===//
 
 /// This is a simple analysis state that represents whether the associated
-/// lattice anchor (either a block or a control-flow edge) is live.
+/// program point (either a block or a control-flow edge) is live.
 class Executable : public AnalysisState {
 public:
   using AnalysisState::AnalysisState;
 
-  /// Set the state of the lattice anchor to live.
+  /// Set the state of the program point to live.
   ChangeResult setToLive();
 
-  /// Get whether the lattice anchor is live.
+  /// Get whether the program point is live.
   bool isLive() const { return live; }
 
   /// Print the liveness.
   void print(raw_ostream &os) const override;
 
-  /// When the state of the lattice anchor is changed to live, re-invoke
+  /// When the state of the program point is changed to live, re-invoke
   /// subscribed analyses on the operations in the block and on the block
   /// itself.
   void onUpdate(DataFlowSolver *solver) const override;
@@ -60,8 +60,8 @@ class Executable : public AnalysisState {
   }
 
 private:
-  /// Whether the lattice anchor is live. Optimistically assume that the lattice
-  /// anchor is dead.
+  /// Whether the program point is live. Optimistically assume that the program
+  /// point is dead.
   bool live = false;
 
   /// A set of analyses that should be updated when this state changes.
@@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState {
 // CFGEdge
 //===----------------------------------------------------------------------===//
 
-/// This lattice anchor represents a control-flow edge between a block and one
+/// This program point represents a control-flow edge between a block and one
 /// of its successors.
 class CFGEdge
-    : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> {
+    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
 public:
   using Base::Base;
 
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 7917f1e3ba6485..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -91,16 +91,15 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
                                            const AbstractDenseLattice &before,
                                            AbstractDenseLattice *after) = 0;
 
-  /// Get the dense lattice after the execution of the given lattice anchor.
-  virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+  /// Get the dense lattice after the execution of the given program point.
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
 
   /// Get the dense lattice after the execution of the given program point and
-  /// add it as a dependency to a lattice anchor. That is, every time the
-  /// lattice after anchor is updated, the dependent program point must be
-  /// visited, and the newly triggered visit might update the lattice after
-  /// dependent.
+  /// add it as a dependency to a program point. That is, every time the lattice
+  /// after `point` is updated, the `dependent` program point must be visited,
+  /// and the newly triggered visit might update the lattice after `dependent`.
   const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
-                                            LatticeAnchor anchor);
+                                            ProgramPoint point);
 
   /// Set the dense lattice at control flow entry point and propagate an update
   /// if it changed.
@@ -250,9 +249,9 @@ class DenseForwardDataFlowAnalysis
   }
 
 protected:
-  /// Get the dense lattice on this lattice anchor.
-  LatticeT *getLattice(LatticeAnchor anchor) override {
-    return getOrCreate<LatticeT>(anchor);
+  /// Get the dense lattice after this program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
   }
 
   /// Set the dense lattice at control flow entry point and propagate an update
@@ -332,16 +331,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
                                            const AbstractDenseLattice &after,
                                            AbstractDenseLattice *before) = 0;
 
-  /// Get the dense lattice before the execution of the lattice anchor. That is,
+  /// Get the dense lattice before the execution of the program point. That is,
   /// before the execution of the given operation or after the execution of the
   /// block.
-  virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
 
-  /// Get the dense lattice before the execution of the program point in
-  /// `anchor` and declare that the `dependent` program point must be updated
-  /// every time `point` is.
+  /// Get the dense lattice before the execution of the program point `point`
+  /// and declare that the `dependent` program point must be updated every time
+  /// `point` is.
   const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
-                                            LatticeAnchor anchor);
+                                            ProgramPoint point);
 
   /// Set the dense lattice before at the control flow exit point and propagate
   /// the update if it changed.
@@ -501,9 +500,9 @@ class DenseBackwardDataFlowAnalysis
   }
 
 protected:
-  /// Get the dense lattice at the given lattice anchor.
-  LatticeT *getLattice(LatticeAnchor anchor) override {
-    return getOrCreate<LatticeT>(anchor);
+  /// Get the dense lattice at the given program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
   }
 
   /// Set the dense lattice at control flow exit point (after the terminator)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -50,7 +50,7 @@ class IntegerRangeAnalysis
   /// At an entry point, we cannot reason about interger value ranges.
   void setToEntryState(IntegerValueRangeLattice *lattice) override {
     propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
-                                    lattice->getAnchor())));
+                                    lattice->getPoint())));
   }
 
   /// Visit an operation. Invoke the transfer function on each operation that
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 933790b4f2a6eb..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState {
   /// Lattices can only be created for values.
   AbstractSparseLattice(Value value) : AnalysisState(value) {}
 
-  /// Return the value this lattice is located at.
-  Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
 
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.
@@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice {
 public:
   using AbstractSparseLattice::AbstractSparseLattice;
 
-  /// Return the value this lattice is located at.
-  Value getAnchor() const { return anchor.get<Value>(); }
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return point.get<Value>(); }
 
   /// Return the value held by this lattice. This requires that the value is
   /// initialized.
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index b0450ecdbd99b8..2580ec28b51902 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -49,93 +49,79 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
 /// Forward declare the analysis state class.
 class AnalysisState;
 
-/// Program point represents a specific location in the execution of a program.
-/// A sequence of program points can be combined into a control flow graph.
-struct ProgramPoint : public PointerUnion<Operation *, Block *> {
-  using ParentTy = PointerUnion<Operation *, Block *>;
-  /// Inherit constructors.
-  using ParentTy::PointerUnion;
-  /// Allow implicit conversion from the parent type.
-  ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
-  /// Allow implicit conversions from operation wrappers.
-  /// TODO: For Windows only. Find a better solution.
-  template <typename OpT, typename = std::enable_if_t<
-                              std::is_convertible<OpT, Operation *>::value &&
-                              !std::is_same<OpT, Operation *>::value>>
-  ProgramPoint(OpT op) : ParentTy(op) {}
-
-  /// Print the program point.
-  void print(raw_ostream &os) const;
-};
-
 //===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
 //===----------------------------------------------------------------------===//
 
-/// Abstract class for generic lattice anchor. In classical data-flow analysis,
-/// lattice anchor represent positions in a program to which lattice elements
+/// Abstract class for generic program points. In classical data-flow analysis,
+/// program points represent positions in a program to which lattice elements
 /// are attached. In sparse data-flow analysis, these can be SSA values, and in
 /// dense data-flow analysis, these are the program points before and after
 /// every operation.
 ///
-/// Lattice anchor are implemented using MLIR's storage uniquer framework and
+/// In the general MLIR data-flow analysis framework, program points are an
+/// extensible concept. Program points are uniquely identifiable objects to
+/// which analysis states can be attached. The semantics of program points are
+/// defined by the analyses that specify their transfer functions.
+///
+/// Program points are implemented using MLIR's storage uniquer framework and
 /// type ID system to provide RTTI.
-class GenericLatticeAnchor : public StorageUniquer::BaseStorage {
+class GenericProgramPoint : public StorageUniquer::BaseStorage {
 public:
-  virtual ~GenericLatticeAnchor();
+  virtual ~GenericProgramPoint();
 
-  /// Get the abstract lattice anchor's type identifier.
+  /// Get the abstract program point's type identifier.
   TypeID getTypeID() const { return typeID; }
 
-  /// Get a derived source location for the lattice anchor.
+  /// Get a derived source location for the program point.
   virtual Location getLoc() const = 0;
 
-  /// Print the lattice anchor.
+  /// Print the program point.
   virtual void print(raw_ostream &os) const = 0;
 
 protected:
-  /// Create an abstract lattice anchor with type identifier.
-  explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {}
+  /// Create an abstract program point with type identifier.
+  explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
 
 private:
-  /// The type identifier of the lattice anchor.
+  /// The type identifier of the program point.
   TypeID typeID;
 };
 
 //===----------------------------------------------------------------------===//
-// GenericLatticeAnchorBase
+// GenericProgramPointBase
 //===----------------------------------------------------------------------===//
 
-/// Base class for generic lattice anchor based on a concrete lattice anchor
+/// Base class for generic program points based on a concrete program point
 /// type and a content key. This class defines the common methods required for
 /// operability with the storage uniquer framework.
 ///
-/// The provided key type uniquely identifies the concrete lattice anchor
+/// The provided key type uniquely identifies the concrete program point
 /// instance and are the data members of the class.
 template <typename ConcreteT, typename Value>
-class GenericLatticeAnchorBase : public GenericLatticeAnchor {
+class GenericProgramPointBase : public GenericProgramPoint {
 public:
   /// The concrete key type used by the storage uniquer. This class is uniqued
   /// by its contents.
   using KeyTy = Value;
   /// Alias for the base class.
-  using Base = GenericLatticeAnchorBase<ConcreteT, Value>;
+  using Base = GenericProgramPointBase<ConcreteT, Value>;
 
-  /// Construct an instance of the lattice anchor using the provided value and
+  /// Construct an instance of the program point using the provided value and
   /// the type ID of the concrete type.
   template <typename ValueT>
-  explicit GenericLatticeAnchorBase(ValueT &&value)
-      : GenericLatticeAnchor(TypeID::get<ConcreteT>()),
+  explicit GenericProgramPointBase(ValueT &&value)
+      : GenericProgramPoint(TypeID::get<ConcreteT>()),
         value(std::forward<ValueT>(value)) {}
 
-  /// Get a uniqued instance of this lattice anchor class with the given
+  /// Get a uniqued instance of this program point class with the given
   /// arguments.
   template <typename... Args>
   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
   }
 
-  /// Allocate space for a lattice anchor and construct it in-place.
+  /// Allocate space for a program point and construct it in-place.
   template <typename ValueT>
   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
                               ValueT &&value) {
@@ -143,48 +129,46 @@ class GenericLatticeAnchorBase : public GenericLatticeAnchor {
         ConcreteT(std::forward<ValueT>(value));
   }
 
-  /// Two lattice anchors are equal if their values are equal.
+  /// Two program points are equal if their values are equal.
   bool operator==(const Value &value) const { return this->value == value; }
 
   /// Provide LLVM-style RTTI using type IDs.
-  static bool classof(const GenericLatticeAnchor *point) {
+  static bool classof(const GenericProgramPoint *point) {
     return point->getTypeID() == TypeID::get<ConcreteT>();
   }
 
-  /// Get the contents of the lattice anchor.
+  /// Get the contents of the program point.
   const Value &getValue() const { return value; }
 
 private:
-  /// The lattice anchor value.
+  /// The program point value.
   Value value;
 };
 
 //===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
 //===----------------------------------------------------------------------===//
 
-/// Fundamental IR components are supported as first-class lattice anchor.
-struct LatticeAnchor
-    : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
-  using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
+/// Fundamental IR components are supported as first-class program points.
+struct ProgramPoint
+    : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
+  using ParentTy =
+      PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
   /// Inherit constructors.
   using ParentTy::PointerUnion;
   /// Allow implicit conversion from the parent type.
-  LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
+  ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
   /// Allow implicit conversions from operation wrappers.
   /// TODO: For Windows only. Find a better solution.
   template <typename OpT, typename = std::enable_if_t<
                               std::is_convertible<OpT, Operation *>::value &&
                               !std::is_same<OpT, Operation *>::value>>
-  LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
-
-  LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
-  LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
+  ProgramPoint(OpT op) : ParentTy(op) {}
 
-  /// Print the lattice anchor.
+  /// Print the program point.
   void print(raw_ostream &os) const;
 
-  /// Get the source location of the lattice anchor.
+  /// Get the source location of the program point.
   Location getLoc() const;
 };
 
@@ -223,8 +207,8 @@ class DataFlowConfig {
 
 /// The general data-flow analysis solver. This class is responsible for
 /// orchestrating child data-flow analyses, running the fixed-point iteration
-/// algorithm, managing analysis state and lattice anchor memory, and tracking
-/// dependencies between analyses, lattice anchor, and analysis states.
+/// algorithm, managing analysis state and program point memory, and tracking
+/// dependencies between analyses, program points, and analysis states.
 ///
 /// Steps to run a data-flow analysis:
 ///
@@ -248,33 +232,32 @@ class DataFlowSolver {
   /// operation and run the analysis until fixpoint.
   LogicalResult initializeAndRun(Operation *top);
 
-  /// Lookup an analysis state for the given lattice anchor. Returns null if one
+  /// Lookup an analysis state for the given program point. Returns null if one
   /// does not exist.
-  template <typename StateT, typename AnchorT>
-  const StateT *lookupState(AnchorT anchor) const {
-    auto it =
-        analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
+  template <typename StateT, typename PointT>
+  const StateT *lookupState(PointT point) const {
+    auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
     if (it == analysisStates.end())
       return nullptr;
     return static_cast<const StateT *>(it->second.get());
   }
 
-  /// Erase any analysis state associated with the given lattice anchor.
-  template <typename AnchorT>
-  void eraseState(AnchorT anchor) {
-    LatticeAnchor la(anchor);
+  /// Erase any analysis state associated with the given program point.
+  template <typename PointT>
+  void eraseState(PointT point) {
+    ProgramPoint pp(point);
 
     for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
-      if (it->first.first == la)
+      if (it->first.first == pp)
         analysisStates.erase(it);
     }
   }
 
-  /// Get a uniqued lattice anchor instance. If one is not present, it is
+  /// Get a uniqued program point instance. If one is not present, it is
   /// created with the provided arguments.
-  template <typename AnchorT, typename... Args>
-  AnchorT *getLatticeAnchor(Args &&...args) {
-    return AnchorT::get(uniquer, std::forward<Args>(args)...);
+  template <typename PointT, typename... Args>
+  PointT *getProgramPoint(Args &&...args) {
+    return PointT::get(uniquer, std::forward<Args>(args)...);
   }
 
   /// A work item on the solver queue is a program point, child analysis pair.
@@ -284,10 +267,10 @@ class DataFlowSolver {
   /// Push a work item onto the worklist.
   void enqueue(WorkItem item) { worklist.push(std::move(item)); }
 
-  /// Get the state associated with the given lattice anchor. If it does not
+  /// Get the state associated with the given program point. If it does not
   /// exist, create an uninitialized state.
-  template <typename StateT, typename AnchorT>
-  StateT *getOrCreateState(AnchorT anchor);
+  template <typename StateT, typename PointT>
+  StateT *getOrCreateState(PointT point);
 
   /// Propagate an update to an analysis state if it changed by pushing
   /// dependent work items to the back of the queue.
@@ -308,13 +291,13 @@ class DataFlowSolver {
   /// Type-erased instances of the children analyses.
   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
 
-  /// The storage uniquer instance that owns the memory of the allocated lattice
-  /// anchors
+  /// The storage uniquer instance that owns the memory of the allocated program
+  /// points.
   StorageUniquer uniquer;
 
-  /// A type-erased map of lattice anchors to associated analysis states for
-  /// first-class lattice anchors.
-  DenseMap<std::pair<LatticeAnchor, TypeID>, std::unique_ptr<AnalysisState>>
+  /// A type-erased map of program points to associated analysis states for
+  /// first-class program points.
+  DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
       analysisStates;
 
   /// Allow the base child analysis class to access the internals of the solver.
@@ -326,13 +309,13 @@ class DataFlowSolver {
 //===----------------------------------------------------------------------===//
 
 /// Base class for generic analysis states. Analysis states contain data-flow
-/// information that are attached to lattice anchors and which evolve as the
+/// information that are attached to program points and which evolve as the
 /// analysis iterates.
 ///
 /// This class places no restrictions on the semantics of analysis states beyond
 /// these requirements.
 ///
-/// 1. Querying the state of a lattice anchor prior to visiting that anchor
+/// 1. Querying the state of a program point prior to visiting that point
 ///    results in uninitialized state. Analyses must be aware of unintialized
 ///    states.
 /// 2. Analysis states can reach fixpoints, where subsequent updates will never
@@ -343,20 +326,20 @@ class AnalysisState {
 public:
   virtual ~AnalysisState();
 
-  /// Create the analysis state at the given lattice anchor.
-  AnalysisState(LatticeAnchor anchor) : anchor(anchor) {}
+  /// Create the analysis state at the given program point.
+  AnalysisState(ProgramPoint point) : point(point) {}
 
-  /// Returns the lattice anchor this state is located at.
-  LatticeAnchor getAnchor() const { return anchor; }
+  /// Returns the program point this state is located at.
+  ProgramPoint getPoint() const { return point; }
 
   /// Print the contents of the analysis state.
   virtual void print(raw_ostream &os) const = 0;
   LLVM_DUMP_METHOD void dump() const;
 
-  /// Add a dependency to this analysis state on a lattice anchor and an
+  /// Add a dependency to this analysis state on a program point and an
   /// analysis. If this state is updated, the analysis will be invoked on the
-  /// given lattice anchor again (in onUpdate()).
-  void addDependency(ProgramPoint point, DataFlowAnalysis *analysis);
+  /// given program point again (in onUpdate()).
+  void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
 
 protected:
   /// This function is called by the solver when the analysis state is updated
@@ -368,8 +351,8 @@ class AnalysisState {
       solver->enqueue(item);
   }
 
-  /// The lattice anchor to which the state belongs.
-  LatticeAnchor anchor;
+  /// The program point to which the state belongs.
+  ProgramPoint point;
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// When compiling with debugging, keep a name for the analysis state.
@@ -378,8 +361,8 @@ class AnalysisState {
 
 private:
   /// The dependency relations originating from this analysis state. An entry
-  /// `state -> (analysis, anchor)` is created when `analysis` queries `state`
-  /// when updating `anchor`.
+  /// `state -> (analysis, point)` is created when `analysis` queries `state`
+  /// when updating `point`.
   ///
   /// When this state is updated, all dependent child analysis invocations are
   /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
@@ -420,7 +403,7 @@ class DataFlowAnalysis {
   explicit DataFlowAnalysis(DataFlowSolver &solver);
 
   /// Initialize the analysis from the provided top-level operation by building
-  /// an initial dependency graph between all lattice anchors of interest. This
+  /// an initial dependency graph between all program points of interest. This
   /// can be implemented by calling `visit` on all program points of interest
   /// below the top-level operation.
   ///
@@ -449,39 +432,39 @@ class DataFlowAnalysis {
   virtual LogicalResult visit(ProgramPoint point) = 0;
 
 protected:
-  /// Create a dependency between the given analysis state and lattice anchor
+  /// Create a dependency between the given analysis state and program point
   /// on this analysis.
   void addDependency(AnalysisState *state, ProgramPoint point);
 
   /// Propagate an update to a state if it changed.
   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
 
-  /// Register a custom lattice anchor class.
-  template <typename AnchorT>
-  void registerAnchorKind() {
-    solver.uniquer.registerParametricStorageType<AnchorT>();
+  /// Register a custom program point class.
+  template <typename PointT>
+  void registerPointKind() {
+    solver.uniquer.registerParametricStorageType<PointT>();
   }
 
-  /// Get or create a custom lattice anchor.
-  template <typename AnchorT, typename... Args>
-  AnchorT *getLatticeAnchor(Args &&...args) {
-    return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
+  /// Get or create a custom program point.
+  template <typename PointT, typename... Args>
+  PointT *getProgramPoint(Args &&...args) {
+    return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
   }
 
-  /// Get the analysis state associated with the lattice anchor. The returned
+  /// Get the analysis state associated with the program point. The returned
   /// state is expected to be "write-only", and any updates need to be
   /// propagated by `propagateIfChanged`.
-  template <typename StateT, typename AnchorT>
-  StateT *getOrCreate(AnchorT anchor) {
-    return solver.getOrCreateState<StateT>(anchor);
+  template <typename StateT, typename PointT>
+  StateT *getOrCreate(PointT point) {
+    return solver.getOrCreateState<StateT>(point);
   }
 
   /// Get a read-only analysis state for the given point and create a dependency
   /// on `dependent`. If the return state is updated elsewhere, this analysis is
   /// re-invoked on the dependent.
-  template <typename StateT, typename AnchorT>
-  const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) {
-    StateT *state = getOrCreate<StateT>(anchor);
+  template <typename StateT, typename PointT>
+  const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
+    StateT *state = getOrCreate<StateT>(point);
     addDependency(state, dependent);
     return state;
   }
@@ -511,12 +494,12 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
   return static_cast<AnalysisT *>(childAnalyses.back().get());
 }
 
-template <typename StateT, typename AnchorT>
-StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
+template <typename StateT, typename PointT>
+StateT *DataFlowSolver::getOrCreateState(PointT point) {
   std::unique_ptr<AnalysisState> &state =
-      analysisStates[{LatticeAnchor(anchor), TypeID::get<StateT>()}];
+      analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
   if (!state) {
-    state = std::unique_ptr<StateT>(new StateT(anchor));
+    state = std::unique_ptr<StateT>(new StateT(point));
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
     state->debugName = llvm::getTypeName<StateT>();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -529,32 +512,20 @@ inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
   return os;
 }
 
-inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) {
-  anchor.print(os);
+inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
+  point.print(os);
   return os;
 }
 
 } // end namespace mlir
 
 namespace llvm {
-/// Allow hashing of lattice anchors and program points.
-template <>
-struct DenseMapInfo<mlir::LatticeAnchor>
-    : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
-
+/// Allow hashing of program points.
 template <>
 struct DenseMapInfo<mlir::ProgramPoint>
     : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
 
 // Allow llvm::cast style functions.
-template <typename To>
-struct CastInfo<To, mlir::LatticeAnchor>
-    : public CastInfo<To, mlir::LatticeAnchor::PointerUnion> {};
-
-template <typename To>
-struct CastInfo<To, const mlir::LatticeAnchor>
-    : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};
-
 template <typename To>
 struct CastInfo<To, mlir::ProgramPoint>
     : public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
@@ -563,11 +534,6 @@ template <typename To>
 struct CastInfo<To, const mlir::ProgramPoint>
     : public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
 
-/// Allow stealing the low bits of a ProgramPoint.
-template <>
-struct PointerLikeTypeTraits<mlir::ProgramPoint>
-    : public PointerLikeTypeTraits<mlir::ProgramPoint::ParentTy> {};
-
 } // end namespace llvm
 
 #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 532480b6fad57d..fab2bd83888da8 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -46,20 +46,17 @@ void Executable::print(raw_ostream &os) const {
 void Executable::onUpdate(DataFlowSolver *solver) const {
   AnalysisState::onUpdate(solver);
 
-  if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) {
-    if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) {
-      // Re-invoke the analyses on the block itself.
-      for (DataFlowAnalysis *analysis : subscribers)
-        solver->enqueue({block, analysis});
-      // Re-invoke the analyses on all operations in the block.
-      for (DataFlowAnalysis *analysis : subscribers)
-        for (Operation &op : *block)
-          solver->enqueue({&op, analysis});
-    }
-  } else if (auto *latticeAnchor =
-                 llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
+  if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
+    // Re-invoke the analyses on the block itself, then on its operations.
+    for (DataFlowAnalysis *analysis : subscribers)
+      solver->enqueue({block, analysis});
+    for (DataFlowAnalysis *analysis : subscribers)
+      for (Operation &op : *block)
+        solver->enqueue({&op, analysis});
+  } else if (auto *programPoint =
+                 llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
     // Re-invoke the analysis on the successor block.
-    if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
+    if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
       for (DataFlowAnalysis *analysis : subscribers)
         solver->enqueue({edge->getTo(), analysis});
     }
@@ -117,7 +114,7 @@ void CFGEdge::print(raw_ostream &os) const {
 
 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
     : DataFlowAnalysis(solver) {
-  registerAnchorKind<CFGEdge>();
+  registerPointKind<CFGEdge>();
 }
 
 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
@@ -221,8 +218,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
   auto *state = getOrCreate<Executable>(to);
   propagateIfChanged(state, state->setToLive());
-  auto *edgeState =
-      getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
+  auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
   propagateIfChanged(edgeState, edgeState->setToLive());
 }
 
@@ -238,7 +234,9 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
 LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
   if (point.is<Block *>())
     return success();
-  auto *op = point.get<Operation *>();
+  auto *op = llvm::dyn_cast_if_present<Operation *>(point);
+  if (!op)
+    return emitError(point.getLoc(), "unknown program point kind");
 
   // If the parent block is not executable, there is nothing to do.
   if (!getOrCreate<Executable>(op->getBlock())->isLive())
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 37f4ceaaa56cee..33c877f78f4bf6 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -47,7 +47,10 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
 LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
     return processOperation(op);
-  visitBlock(point.get<Block *>());
+  else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
+    visitBlock(block);
+  else
+    return failure();
   return success();
 }
 
@@ -177,7 +180,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
     // Skip control edges that aren't executable.
     Block *predecessor = *it;
     if (!getOrCreateFor<Executable>(
-             block, getLatticeAnchor<CFGEdge>(predecessor, block))
+             block, getProgramPoint<CFGEdge>(predecessor, block))
              ->isLive())
       continue;
 
@@ -245,8 +248,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
 
 const AbstractDenseLattice *
 AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
-                                                    LatticeAnchor anchor) {
-  AbstractDenseLattice *state = getLattice(anchor);
+                                                    ProgramPoint point) {
+  AbstractDenseLattice *state = getLattice(point);
   addDependency(state, dependent);
   return state;
 }
@@ -276,7 +279,10 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
 LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
     return processOperation(op);
-  visitBlock(point.get<Block *>());
+  else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
+    visitBlock(block);
+  else
+    return failure();
   return success();
 }
 
@@ -418,7 +424,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
   // Meet the state with the state before block's successors.
   for (Block *successor : block->getSuccessors()) {
     if (!getOrCreateFor<Executable>(block,
-                                    getLatticeAnchor<CFGEdge>(block, successor))
+                                    getProgramPoint<CFGEdge>(block, successor))
              ->isLive())
       continue;
 
@@ -468,8 +474,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
 
 const AbstractDenseLattice *
 AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
-                                                     LatticeAnchor anchor) {
-  AbstractDenseLattice *state = getLattice(anchor);
+                                                     ProgramPoint point) {
+  AbstractDenseLattice *state = getLattice(point);
   addDependency(state, dependent);
   return state;
 }
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9a95f172d5df48..35d38ea02d7162 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -42,7 +42,7 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   // If the integer range can be narrowed to a constant, update the constant
   // value of the SSA value.
   std::optional<APInt> constant = getValue().getValue().getConstantValue();
-  auto value = anchor.get<Value>();
+  auto value = point.get<Value>();
   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
   if (!constant)
     return solver->propagateIfChanged(
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4a73f21a18aae7..d47d5fec8a9a6a 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -34,7 +34,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
   AnalysisState::onUpdate(solver);
 
   // Push all users of the value to the queue.
-  for (Operation *user : anchor.get<Value>().getUsers())
+  for (Operation *user : point.get<Value>().getUsers())
     for (DataFlowAnalysis *analysis : useDefSubscribers)
       solver->enqueue({user, analysis});
 }
@@ -46,7 +46,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
 AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
     DataFlowSolver &solver)
     : DataFlowAnalysis(solver) {
-  registerAnchorKind<CFGEdge>();
+  registerPointKind<CFGEdge>();
 }
 
 LogicalResult
@@ -86,7 +86,10 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
 LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
     return visitOperation(op);
-  visitBlock(point.get<Block *>());
+  else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
+    visitBlock(block);
+  else
+    return failure();
   return success();
 }
 
@@ -214,7 +217,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
     // If the edge from the predecessor block to the current block is not live,
     // bail out.
     auto *edgeExecutable =
-        getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
+        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
     edgeExecutable->blockContentSubscribe(this);
     if (!edgeExecutable->isLive())
       continue;
@@ -321,7 +324,7 @@ void AbstractSparseForwardDataFlowAnalysis::join(
 AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
     DataFlowSolver &solver, SymbolTableCollection &symbolTable)
     : DataFlowAnalysis(solver), symbolTable(symbolTable) {
-  registerAnchorKind<CFGEdge>();
+  registerPointKind<CFGEdge>();
 }
 
 LogicalResult
@@ -352,10 +355,14 @@ LogicalResult
 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
   if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
     return visitOperation(op);
-  // For backward dataflow, we don't have to do any work for the blocks
-  // themselves. CFG edges between blocks are processed by the BranchOp
-  // logic in `visitOperation`, and entry blocks for functions are tied
-  // to the CallOp arguments by visitOperation.
+  else if (llvm::dyn_cast_if_present<Block *>(point))
+    // For backward dataflow, we don't have to do any work for the blocks
+    // themselves. CFG edges between blocks are processed by the BranchOp
+    // logic in `visitOperation`, and entry blocks for functions are tied
+    // to the CallOp arguments by visitOperation.
+    return success();
+  else
+    return failure();
   return success();
 }
 
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index a65ddc13143bae..d0e827aa1c2b64 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -26,10 +26,10 @@
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
 //===----------------------------------------------------------------------===//
 
-GenericLatticeAnchor::~GenericLatticeAnchor() = default;
+GenericProgramPoint::~GenericProgramPoint() = default;
 
 //===----------------------------------------------------------------------===//
 // AnalysisState
@@ -44,7 +44,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
   DATAFLOW_DEBUG({
     if (inserted) {
       llvm::dbgs() << "Creating dependency between " << debugName << " of "
-                   << anchor << "\nand " << debugName << " on " << dependent
+                   << point << "\nand " << debugName << " on " << dependent
                    << "\n";
     }
   });
@@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
 void AnalysisState::dump() const { print(llvm::errs()); }
 
 //===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
 //===----------------------------------------------------------------------===//
 
 void ProgramPoint::print(raw_ostream &os) const {
@@ -61,36 +61,23 @@ void ProgramPoint::print(raw_ostream &os) const {
     os << "<NULL POINT>";
     return;
   }
-  if (Operation *op = llvm::dyn_cast<Operation *>(*this)) {
+  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
+    return programPoint->print(os);
+  if (auto *op = llvm::dyn_cast<Operation *>(*this))
     return op->print(os, OpPrintingFlags().skipRegions());
-  }
-  return get<Block *>()->print(os);
-}
-
-void LatticeAnchor::print(raw_ostream &os) const {
-  if (isNull()) {
-    os << "<NULL POINT>";
-    return;
-  }
-  if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
-    return LatticeAnchor->print(os);
-  if (auto value = llvm::dyn_cast<Value>(*this)) {
+  if (auto value = llvm::dyn_cast<Value>(*this))
     return value.print(os, OpPrintingFlags().skipRegions());
-  }
-
-  return get<ProgramPoint>().print(os);
+  return get<Block *>()->print(os);
 }
 
-Location LatticeAnchor::getLoc() const {
-  if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
-    return LatticeAnchor->getLoc();
+Location ProgramPoint::getLoc() const {
+  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
+    return programPoint->getLoc();
+  if (auto *op = llvm::dyn_cast<Operation *>(*this))
+    return op->getLoc();
   if (auto value = llvm::dyn_cast<Value>(*this))
     return value.getLoc();
-
-  ProgramPoint pp = get<ProgramPoint>();
-  if (auto *op = llvm::dyn_cast<Operation *>(pp))
-    return op->getLoc();
-  return pp.get<Block *>()->getParent()->getLoc();
+  return get<Block *>()->getParent()->getLoc();
 }
 
 //===----------------------------------------------------------------------===//
@@ -130,7 +117,7 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
                                         ChangeResult changed) {
   if (changed == ChangeResult::Change) {
     DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
-                                << " of " << state->anchor << "\n"
+                                << " of " << state->point << "\n"
                                 << "Value: " << *state << "\n");
     state->onUpdate(this);
   }
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index d02efaaa3fe320..90973af9c2cf5d 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -40,7 +40,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
           pred->printAsOperand(os);
           os << " = ";
           auto *live = solver.lookupState<Executable>(
-              solver.getLatticeAnchor<CFGEdge>(pred, &block));
+              solver.getProgramPoint<CFGEdge>(pred, &block));
           if (live)
             os << *live;
           else
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
index 86eb8651cb90c1..57fe0ca458de21 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
@@ -206,7 +206,7 @@ class UnderlyingValueAnalysis
   /// At an entry point, the underlying value of a value is itself.
   void setToEntryState(UnderlyingValueLattice *lattice) override {
     propagateIfChanged(lattice,
-                       lattice->join(UnderlyingValue{lattice->getAnchor()}));
+                       lattice->join(UnderlyingValue{lattice->getPoint()}));
   }
 
   /// Look for the most underlying value of a value.
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 9573ec1d143257..b6b33182440cf4 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -115,11 +115,15 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
 }
 
 LogicalResult FooAnalysis::visit(ProgramPoint point) {
-  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
     visitOperation(op);
-  else
-    visitBlock(point.get<Block *>());
-  return success();
+    return success();
+  }
+  if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
+    visitBlock(block);
+    return success();
+  }
+  return emitError(point.getLoc(), "unknown point kind");
 }
 
 void FooAnalysis::visitBlock(Block *block) {

>From 499c6b044e90c9a6676eace35e32fd233bc3c4e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 13:51:07 +0100
Subject: [PATCH 2/4] Add speculation for linalg structured ops

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 34 +++++++++++++++++++
 .../mlir-linalg-ods-yaml-gen.cpp              |  5 ++-
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..31f29139247267 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -29,6 +29,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat([
        SingleBlockImplicitTerminator<"YieldOp">,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
        DestinationStyleOpInterface,
        LinalgStructuredInterface,
        ReifyRankedShapedTypeOpInterface], props)> {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..790e2ac50d9e39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
@@ -1202,6 +1203,23 @@ void GenericOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+static Speculation::Speculatability
+getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
+  // Operands with value semantics are speculatable, while operands with memory
+  // semantics are not.
+  for (Value operand : linalgOp->getOperands()) {
+    if (isa<MemRefType>(operand.getType())) {
+      return Speculation::NotSpeculatable;
+    }
+  }
+  // The body of the op can still have speculation in it's region.
+  return Speculation::RecursivelySpeculatable;
+}
+
+Speculation::Speculatability GenericOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -1553,6 +1571,10 @@ void MapOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability MapOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
@@ -1621,6 +1643,10 @@ void ReduceOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability ReduceOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
@@ -1906,6 +1932,10 @@ void TransposeOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability TransposeOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &result) {
   // Only the tensor type is supported.
@@ -2134,6 +2164,10 @@ void BroadcastOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability BroadcastOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 7311cdd39d0755..e5ce845b94b39b 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -656,7 +656,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 }
 )FMT";
 
-// Implementations of fold and getEffects.
+// Implementations of fold, getEffects and getSpeculatability.
 // Parameters:
 // {0}: Class name
 const char structuredOpFoldersFormat[] = R"FMT(
@@ -669,6 +669,9 @@ void {0}::getEffects(SmallVectorImpl<
       if (hasPureTensorSemantics()) return;
       getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
+Speculation::Speculatability {0}::getSpeculatability() {{
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
 )FMT";
 
 // Implementation of parse/print.

>From f32be329cc8e68ef6edd78689462be3deec2e892 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 15:14:05 +0100
Subject: [PATCH 3/4] Use hasPureTensorSemantics

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 790e2ac50d9e39..e2af2e1caad811 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1207,11 +1207,8 @@ static Speculation::Speculatability
 getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
   // Operands with value semantics are speculatable, while operands with memory
   // semantics are not.
-  for (Value operand : linalgOp->getOperands()) {
-    if (isa<MemRefType>(operand.getType())) {
-      return Speculation::NotSpeculatable;
-    }
-  }
+  if (!linalgOp.hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
   // The body of the op can still have speculation in it's region.
   return Speculation::RecursivelySpeculatable;
 }

>From 593b2b880b6fac3ab33a66f8a0e31da7d20184ce Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 10 Sep 2024 15:15:11 +0100
Subject: [PATCH 4/4] Add test

---
 .../loop-invariant-code-motion.mlir           | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 47a49465e8a7cd..dbcac818a728b5 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -1118,3 +1118,48 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
   }
   return %0 : i32
 }
+
+// -----
+
+#trait = {
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Pure-tensor linalg ops with loop-invariant operands are hoisted by LICM.
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,
+                            %b : tensor<128x128xf32>,
+                            %c : tensor<128x128xf32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+  %final = scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                            -> tensor<?x128xf32> {
+    %compute = linalg.generic #trait
+               ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+               outs(%c : tensor<128x128xf32>) {
+    ^bb0(%in : f32, %in2 : f32, %in3 : f32):
+      %mul = arith.mulf %in, %in2 : f32
+      %add = arith.addf %mul, %in3 : f32
+      linalg.yield %add : f32
+    } -> tensor<128x128xf32>
+
+    %newacc = tensor.insert_slice %compute into
+                                  %acc[%i, 0][128, 128][1, 1]
+                                  : tensor<128x128xf32> into tensor<?x128xf32>
+    scf.yield %newacc : tensor<?x128xf32>
+  }
+
+  func.return %final : tensor<?x128xf32>
+}



More information about the Mlir-commits mailing list