[Mlir-commits] [mlir] [MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes (PR #82620)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 22 05:49:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andi Drebes (andidr)

<details>
<summary>Changes</summary>

The class `Lattice` should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function called `meet`. This process fails for two reasons:

  1. `Lattice::has_meet` checks whether the lattice value class has a non-static member function `meet` that takes no arguments, although it should check for a static member function `meet`.

  2. The function template `Lattice::meet<VT>()`, which implements the default (no-op) meet operation directly in the lattice, is always enabled and takes precedence over the delegating function template `Lattice::meet<VT, std::integral_constant<bool, true>>()`.

This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static `meet` function by conditionally enabling either the delegating function template or the non-delegating function template and by changing `Lattice::has_meet` so that it checks for a static `meet` member function in the lattice value type.

The test in `TestSparseBackwardDataFlowAnalysis.cpp` is changed so that the `meet` function is no longer provided directly by the `WrittenTo` lattice but by its `Lattice` base class, which exercises delegation to the lattice value class.

---
Full diff: https://github.com/llvm/llvm-project/pull/82620.diff


2 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+5-3) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+42-18) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec27..7aadd5409cc695 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
   /// analysis, lattices will only have a `join`, no `meet`, but we want to use
   /// the same `Lattice` class for both directions.
   template <typename T, typename... Args>
-  using has_meet = decltype(std::declval<T>().meet());
+  using has_meet = decltype(&T::meet);
   template <typename T>
   using lattice_has_meet = llvm::is_detected<has_meet, T>;
 
   /// Meet (intersect) the information contained in the 'rhs' value with this
   /// lattice. Returns if the state of the current lattice changed.  If the
   /// lattice elements don't have a `meet` method, this is a no-op (see below.)
-  template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
+  template <typename VT,
+            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     ValueT newValue = ValueT::meet(value, rhs);
     assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
-  template <typename VT>
+  template <typename VT,
+            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     return ChangeResult::NoChange;
   }
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5eb..6b35d4e2c0d8af 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
 
 namespace {
 
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
-  using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing the set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+  bool operator==(const WrittenToLatticeValue &other) {
+    return this->writes == other.writes;
+  }
 
-  void print(raw_ostream &os) const override {
-    os << "[";
-    llvm::interleave(
-        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
-    os << "]";
+  static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    WrittenToLatticeValue res = lhs;
+    (void)res.addWrites(rhs.writes);
+
+    return res;
   }
+
+  static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    // Should not be triggered by this test, but required by `Lattice<T>`
+    assert(false);
+  }
+
   ChangeResult addWrites(const SetVector<StringAttr> &writes) {
     int sizeBefore = this->writes.size();
     this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
     return sizeBefore == sizeAfter ? ChangeResult::NoChange
                                    : ChangeResult::Change;
   }
-  ChangeResult meet(const AbstractSparseLattice &other) override {
-    const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
-    return addWrites(rhs->writes);
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+    os << "]";
   }
 
+  void clear() { writes.clear(); }
+
   SetVector<StringAttr> writes;
 };
 
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice<WrittenToLatticeValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using Lattice::Lattice;
+};
+
 /// An analysis that, by going backwards along the dataflow graph, annotates
 /// each value with all the memory resources it (or anything derived from it)
 /// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
   void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
                          ArrayRef<const WrittenTo *> results) override;
 
-  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+  void setToExitState(WrittenTo *lattice) override {
+    lattice->getValue().clear();
+  }
 
 private:
   bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
   if (auto store = dyn_cast<memref::StoreOp>(op)) {
     SetVector<StringAttr> newWrites;
     newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
-    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    propagateIfChanged(operands[0],
+                       operands[0]->getValue().addWrites(newWrites));
     return;
   } // By default, every result of an op depends on every operand.
   for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "brancharg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "callarg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                              call.getOperation()->getName().getStringRef());
     }
     newWrites.insert(name);
-    propagateIfChanged(lattice, lattice->addWrites(newWrites));
+    propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
   }
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/82620


More information about the Mlir-commits mailing list