[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 07:17:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This patch adds speculation behavior to Linalg structured ops, allowing them to be hoisted out of loops by loop-invariant code motion (LICM).

---

Patch is 50.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108032.diff


17 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h (+8-8) 
- (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+18-19) 
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+1-1) 
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+4-4) 
- (modified) mlir/include/mlir/Analysis/DataFlowFramework.h (+101-135) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1) 
- (modified) mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (+15-17) 
- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+14-8) 
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+1-1) 
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+16-9) 
- (modified) mlir/lib/Analysis/DataFlowFramework.cpp (+16-29) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+31) 
- (modified) mlir/test/Transforms/loop-invariant-code-motion.mlir (+45) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp (+1-1) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+1-1) 
- (modified) mlir/test/lib/Analysis/TestDataFlowFramework.cpp (+8-4) 
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+4-1) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 80c8b86c63678a..10ef8b6ba5843a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -35,21 +35,21 @@ namespace dataflow {
 //===----------------------------------------------------------------------===//
 
 /// This is a simple analysis state that represents whether the associated
-/// lattice anchor (either a block or a control-flow edge) is live.
+/// program point (either a block or a control-flow edge) is live.
 class Executable : public AnalysisState {
 public:
   using AnalysisState::AnalysisState;
 
-  /// Set the state of the lattice anchor to live.
+  /// Set the state of the program point to live.
   ChangeResult setToLive();
 
-  /// Get whether the lattice anchor is live.
+  /// Get whether the program point is live.
   bool isLive() const { return live; }
 
   /// Print the liveness.
   void print(raw_ostream &os) const override;
 
-  /// When the state of the lattice anchor is changed to live, re-invoke
+  /// When the state of the program point is changed to live, re-invoke
   /// subscribed analyses on the operations in the block and on the block
   /// itself.
   void onUpdate(DataFlowSolver *solver) const override;
@@ -60,8 +60,8 @@ class Executable : public AnalysisState {
   }
 
 private:
-  /// Whether the lattice anchor is live. Optimistically assume that the lattice
-  /// anchor is dead.
+  /// Whether the program point is live. Optimistically assume that the program
+  /// point is dead.
   bool live = false;
 
   /// A set of analyses that should be updated when this state changes.
@@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState {
 // CFGEdge
 //===----------------------------------------------------------------------===//
 
-/// This lattice anchor represents a control-flow edge between a block and one
+/// This program point represents a control-flow edge between a block and one
 /// of its successors.
 class CFGEdge
-    : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> {
+    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
 public:
   using Base::Base;
 
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 7917f1e3ba6485..4ad5f3fcd838c0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -91,16 +91,15 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
                                            const AbstractDenseLattice &before,
                                            AbstractDenseLattice *after) = 0;
 
-  /// Get the dense lattice after the execution of the given lattice anchor.
-  virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+  /// Get the dense lattice after the execution of the given program point.
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
 
   /// Get the dense lattice after the execution of the given program point and
-  /// add it as a dependency to a lattice anchor. That is, every time the
-  /// lattice after anchor is updated, the dependent program point must be
-  /// visited, and the newly triggered visit might update the lattice after
-  /// dependent.
+  /// add it as a dependency to a program point. That is, every time the lattice
+  /// after point is updated, the dependent program point must be visited, and
+  /// the newly triggered visit might update the lattice after dependent.
   const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
-                                            LatticeAnchor anchor);
+                                            ProgramPoint point);
 
   /// Set the dense lattice at control flow entry point and propagate an update
   /// if it changed.
@@ -250,9 +249,9 @@ class DenseForwardDataFlowAnalysis
   }
 
 protected:
-  /// Get the dense lattice on this lattice anchor.
-  LatticeT *getLattice(LatticeAnchor anchor) override {
-    return getOrCreate<LatticeT>(anchor);
+  /// Get the dense lattice after this program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
   }
 
   /// Set the dense lattice at control flow entry point and propagate an update
@@ -332,16 +331,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
                                            const AbstractDenseLattice &after,
                                            AbstractDenseLattice *before) = 0;
 
-  /// Get the dense lattice before the execution of the lattice anchor. That is,
+  /// Get the dense lattice before the execution of the program point. That is,
   /// before the execution of the given operation or after the execution of the
   /// block.
-  virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
 
-  /// Get the dense lattice before the execution of the program point in
-  /// `anchor` and declare that the `dependent` program point must be updated
-  /// every time `point` is.
+  /// Get the dense lattice before the execution of the program point `point`
+  /// and declare that the `dependent` program point must be updated every time
+  /// `point` is.
   const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
-                                            LatticeAnchor anchor);
+                                            ProgramPoint point);
 
   /// Set the dense lattice before at the control flow exit point and propagate
   /// the update if it changed.
@@ -501,9 +500,9 @@ class DenseBackwardDataFlowAnalysis
   }
 
 protected:
-  /// Get the dense lattice at the given lattice anchor.
-  LatticeT *getLattice(LatticeAnchor anchor) override {
-    return getOrCreate<LatticeT>(anchor);
+  /// Get the dense lattice at the given program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
   }
 
   /// Set the dense lattice at control flow exit point (after the terminator)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..d4a5472cfde868 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -50,7 +50,7 @@ class IntegerRangeAnalysis
   /// At an entry point, we cannot reason about interger value ranges.
   void setToEntryState(IntegerValueRangeLattice *lattice) override {
     propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
-                                    lattice->getAnchor())));
+                                    lattice->getPoint())));
   }
 
   /// Visit an operation. Invoke the transfer function on each operation that
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 933790b4f2a6eb..89726ae3a855c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState {
   /// Lattices can only be created for values.
   AbstractSparseLattice(Value value) : AnalysisState(value) {}
 
-  /// Return the value this lattice is located at.
-  Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
 
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.
@@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice {
 public:
   using AbstractSparseLattice::AbstractSparseLattice;
 
-  /// Return the value this lattice is located at.
-  Value getAnchor() const { return anchor.get<Value>(); }
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return point.get<Value>(); }
 
   /// Return the value held by this lattice. This requires that the value is
   /// initialized.
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index b0450ecdbd99b8..2580ec28b51902 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -49,93 +49,79 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
 /// Forward declare the analysis state class.
 class AnalysisState;
 
-/// Program point represents a specific location in the execution of a program.
-/// A sequence of program points can be combined into a control flow graph.
-struct ProgramPoint : public PointerUnion<Operation *, Block *> {
-  using ParentTy = PointerUnion<Operation *, Block *>;
-  /// Inherit constructors.
-  using ParentTy::PointerUnion;
-  /// Allow implicit conversion from the parent type.
-  ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
-  /// Allow implicit conversions from operation wrappers.
-  /// TODO: For Windows only. Find a better solution.
-  template <typename OpT, typename = std::enable_if_t<
-                              std::is_convertible<OpT, Operation *>::value &&
-                              !std::is_same<OpT, Operation *>::value>>
-  ProgramPoint(OpT op) : ParentTy(op) {}
-
-  /// Print the program point.
-  void print(raw_ostream &os) const;
-};
-
 //===----------------------------------------------------------------------===//
-// GenericLatticeAnchor
+// GenericProgramPoint
 //===----------------------------------------------------------------------===//
 
-/// Abstract class for generic lattice anchor. In classical data-flow analysis,
-/// lattice anchor represent positions in a program to which lattice elements
+/// Abstract class for generic program points. In classical data-flow analysis,
+/// program points represent positions in a program to which lattice elements
 /// are attached. In sparse data-flow analysis, these can be SSA values, and in
 /// dense data-flow analysis, these are the program points before and after
 /// every operation.
 ///
-/// Lattice anchor are implemented using MLIR's storage uniquer framework and
+/// In the general MLIR data-flow analysis framework, program points are an
+/// extensible concept. Program points are uniquely identifiable objects to
+/// which analysis states can be attached. The semantics of program points are
+/// defined by the analyses that specify their transfer functions.
+///
+/// Program points are implemented using MLIR's storage uniquer framework and
 /// type ID system to provide RTTI.
-class GenericLatticeAnchor : public StorageUniquer::BaseStorage {
+class GenericProgramPoint : public StorageUniquer::BaseStorage {
 public:
-  virtual ~GenericLatticeAnchor();
+  virtual ~GenericProgramPoint();
 
-  /// Get the abstract lattice anchor's type identifier.
+  /// Get the abstract program point's type identifier.
   TypeID getTypeID() const { return typeID; }
 
-  /// Get a derived source location for the lattice anchor.
+  /// Get a derived source location for the program point.
   virtual Location getLoc() const = 0;
 
-  /// Print the lattice anchor.
+  /// Print the program point.
   virtual void print(raw_ostream &os) const = 0;
 
 protected:
-  /// Create an abstract lattice anchor with type identifier.
-  explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {}
+  /// Create an abstract program point with type identifier.
+  explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
 
 private:
-  /// The type identifier of the lattice anchor.
+  /// The type identifier of the program point.
   TypeID typeID;
 };
 
 //===----------------------------------------------------------------------===//
-// GenericLatticeAnchorBase
+// GenericProgramPointBase
 //===----------------------------------------------------------------------===//
 
-/// Base class for generic lattice anchor based on a concrete lattice anchor
+/// Base class for generic program points based on a concrete program point
 /// type and a content key. This class defines the common methods required for
 /// operability with the storage uniquer framework.
 ///
-/// The provided key type uniquely identifies the concrete lattice anchor
+/// The provided key type uniquely identifies the concrete program point
 /// instance and are the data members of the class.
 template <typename ConcreteT, typename Value>
-class GenericLatticeAnchorBase : public GenericLatticeAnchor {
+class GenericProgramPointBase : public GenericProgramPoint {
 public:
   /// The concrete key type used by the storage uniquer. This class is uniqued
   /// by its contents.
   using KeyTy = Value;
   /// Alias for the base class.
-  using Base = GenericLatticeAnchorBase<ConcreteT, Value>;
+  using Base = GenericProgramPointBase<ConcreteT, Value>;
 
-  /// Construct an instance of the lattice anchor using the provided value and
+  /// Construct an instance of the program point using the provided value and
   /// the type ID of the concrete type.
   template <typename ValueT>
-  explicit GenericLatticeAnchorBase(ValueT &&value)
-      : GenericLatticeAnchor(TypeID::get<ConcreteT>()),
+  explicit GenericProgramPointBase(ValueT &&value)
+      : GenericProgramPoint(TypeID::get<ConcreteT>()),
         value(std::forward<ValueT>(value)) {}
 
-  /// Get a uniqued instance of this lattice anchor class with the given
+  /// Get a uniqued instance of this program point class with the given
   /// arguments.
   template <typename... Args>
   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
   }
 
-  /// Allocate space for a lattice anchor and construct it in-place.
+  /// Allocate space for a program point and construct it in-place.
   template <typename ValueT>
   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
                               ValueT &&value) {
@@ -143,48 +129,46 @@ class GenericLatticeAnchorBase : public GenericLatticeAnchor {
         ConcreteT(std::forward<ValueT>(value));
   }
 
-  /// Two lattice anchors are equal if their values are equal.
+  /// Two program points are equal if their values are equal.
   bool operator==(const Value &value) const { return this->value == value; }
 
   /// Provide LLVM-style RTTI using type IDs.
-  static bool classof(const GenericLatticeAnchor *point) {
+  static bool classof(const GenericProgramPoint *point) {
     return point->getTypeID() == TypeID::get<ConcreteT>();
   }
 
-  /// Get the contents of the lattice anchor.
+  /// Get the contents of the program point.
   const Value &getValue() const { return value; }
 
 private:
-  /// The lattice anchor value.
+  /// The program point value.
   Value value;
 };
 
 //===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
 //===----------------------------------------------------------------------===//
 
-/// Fundamental IR components are supported as first-class lattice anchor.
-struct LatticeAnchor
-    : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
-  using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
+/// Fundamental IR components are supported as first-class program points.
+struct ProgramPoint
+    : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
+  using ParentTy =
+      PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
   /// Inherit constructors.
   using ParentTy::PointerUnion;
   /// Allow implicit conversion from the parent type.
-  LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
+  ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
   /// Allow implicit conversions from operation wrappers.
   /// TODO: For Windows only. Find a better solution.
   template <typename OpT, typename = std::enable_if_t<
                               std::is_convertible<OpT, Operation *>::value &&
                               !std::is_same<OpT, Operation *>::value>>
-  LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
-
-  LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
-  LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
+  ProgramPoint(OpT op) : ParentTy(op) {}
 
-  /// Print the lattice anchor.
+  /// Print the program point.
   void print(raw_ostream &os) const;
 
-  /// Get the source location of the lattice anchor.
+  /// Get the source location of the program point.
   Location getLoc() const;
 };
 
@@ -223,8 +207,8 @@ class DataFlowConfig {
 
 /// The general data-flow analysis solver. This class is responsible for
 /// orchestrating child data-flow analyses, running the fixed-point iteration
-/// algorithm, managing analysis state and lattice anchor memory, and tracking
-/// dependencies between analyses, lattice anchor, and analysis states.
+/// algorithm, managing analysis state and program point memory, and tracking
+/// dependencies between analyses, program points, and analysis states.
 ///
 /// Steps to run a data-flow analysis:
 ///
@@ -248,33 +232,32 @@ class DataFlowSolver {
   /// operation and run the analysis until fixpoint.
   LogicalResult initializeAndRun(Operation *top);
 
-  /// Lookup an analysis state for the given lattice anchor. Returns null if one
+  /// Lookup an analysis state for the given program point. Returns null if one
   /// does not exist.
-  template <typename StateT, typename AnchorT>
-  const StateT *lookupState(AnchorT anchor) const {
-    auto it =
-        analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
+  template <typename StateT, typename PointT>
+  const StateT *lookupState(PointT point) const {
+    auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
     if (it == analysisStates.end())
       return nullptr;
     return static_cast<const StateT *>(it->second.get());
   }
 
-  /// Erase any analysis state associated with the given lattice anchor.
-  template <typename AnchorT>
-  void eraseState(AnchorT anchor) {
-    LatticeAnchor la(anchor);
+  /// Erase any analysis state associated with the given program point.
+  template <typename PointT>
+  void eraseState(PointT point) {
+    ProgramPoint pp(point);
 
     for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
-      if (it->first.first == la)
+      if (it->first.first == pp)
         analysisStates.erase(it);
     }
   }
 
-  /// Get a uniqued lattice anchor instance. If one is not present, it is
+  /// Get a uniqued program point instance. If one is not present, it is
   /// created with the provided arguments.
-  template <typename AnchorT, typename... Args>
-  AnchorT *getLatticeAnchor(Args &&...args) {
-    return AnchorT::get(uniquer, std::forward<Args>(args)...);
+  template <typename PointT, typename... Args>
+  PointT *getProgramPoint(Args &&...args) {
+    return PointT::get(uniquer, std::forward<Args>(args)...);
   }
 
   /// A work item on the solver queue is a program point, child analysis pair.
@@ -284,10 +267,10 @@ class DataFlowSolver {
   /// Push a work item onto the worklist.
   void enqueue(WorkItem item) { worklist.push(std::move(item)); }
 
-  /// Get the state associated with the given lattice anchor. If it does not
+  /// Get the state associated with the given program point. If it does not
   /// exist, create an uninitialized state.
-  template <typename StateT, typename AnchorT>
-  StateT *getOrCreateState(AnchorT anchor);
+  template <typename StateT, typename PointT>
+  StateT *getOrCreateState(PointT point);
 
   /// Propagate an update to an analysis state if it changed by pushing
   /// dependent work items to the back of the queue.
@@ -308,13 +291,13 @@ class DataFlowSolver {
   /// Type-erased instances of the children analyses.
   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
 
-  /// The storage uniquer instance that owns the memory of the allocated lattice
-  /// anchors
+  /// The storage uniquer instance that owns the memory of the allocated program
+  /// points.
   StorageUniquer uniquer;
 
-  /// A ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/108032


More information about the Mlir-commits mailing list