[Mlir-commits] [mlir] d1cff36 - [MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes (#82620)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 17 04:41:33 PDT 2024


Author: Andi Drebes
Date: 2024-05-17T13:41:30+02:00
New Revision: d1cff36e5e37cd552ec049335feb1dd8f94517ea

URL: https://github.com/llvm/llvm-project/commit/d1cff36e5e37cd552ec049335feb1dd8f94517ea
DIFF: https://github.com/llvm/llvm-project/commit/d1cff36e5e37cd552ec049335feb1dd8f94517ea.diff

LOG: [MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes (#82620)

The class `Lattice` should automatically delegate invocations of the
meet operator to the meet operation of the associated lattice value
class if that class provides a static function called `meet`. This
delegation currently fails for two reasons:

1. `Lattice::has_meet` checks whether the lattice value class has a
member function `meet` that can be called on an instance without
arguments, although it should check for a static member function.

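2. The function template `Lattice::meet<VT>()`, which implements the
default meet operation directly in the lattice, is always present and
is always selected. The delegating function template is declared as
`template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>`,
whose second template parameter is an unnamed non-type parameter with
no default that can never be deduced. As a result, the delegating
overload is never chosen.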

This change fixes the automatic delegation of the meet operation of a
lattice to the lattice value class in the presence of a static `meet`
function by conditionally enabling either the delegating function
template or the non-delegating function template and by changing
`Lattice::has_meet` so that it checks for a static `meet` member
function in the lattice value type.

The test in `TestSparseBackwardDataFlowAnalysis.cpp` is changed so
that the `meet` function is no longer provided directly by the
`WrittenTo` lattice. Instead, it is provided by the `Lattice` base
class, which exercises delegation to a lattice value class.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec2..7aadd5409cc69 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
   /// analysis, lattices will only have a `join`, no `meet`, but we want to use
   /// the same `Lattice` class for both directions.
   template <typename T, typename... Args>
-  using has_meet = decltype(std::declval<T>().meet());
+  using has_meet = decltype(&T::meet);
   template <typename T>
   using lattice_has_meet = llvm::is_detected<has_meet, T>;
 
   /// Meet (intersect) the information contained in the 'rhs' value with this
   /// lattice. Returns if the state of the current lattice changed.  If the
   /// lattice elements don't have a `meet` method, this is a no-op (see below.)
-  template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
+  template <typename VT,
+            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     ValueT newValue = ValueT::meet(value, rhs);
     assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
-  template <typename VT>
+  template <typename VT,
+            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     return ChangeResult::NoChange;
   }

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5e..3029738046644 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
 
 namespace {
 
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
-  using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing a set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+  bool operator==(const WrittenToLatticeValue &other) {
+    return this->writes == other.writes;
+  }
 
-  void print(raw_ostream &os) const override {
-    os << "[";
-    llvm::interleave(
-        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
-    os << "]";
+  static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    WrittenToLatticeValue res = lhs;
+    (void)res.addWrites(rhs.writes);
+
+    return res;
   }
+
+  static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    // Should not be triggered by this test, but required by `Lattice<T>`
+    llvm_unreachable("Join should not be triggered by this test");
+  }
+
   ChangeResult addWrites(const SetVector<StringAttr> &writes) {
     int sizeBefore = this->writes.size();
     this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
     return sizeBefore == sizeAfter ? ChangeResult::NoChange
                                    : ChangeResult::Change;
   }
-  ChangeResult meet(const AbstractSparseLattice &other) override {
-    const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
-    return addWrites(rhs->writes);
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+    os << "]";
   }
 
+  void clear() { writes.clear(); }
+
   SetVector<StringAttr> writes;
 };
 
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice<WrittenToLatticeValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using Lattice::Lattice;
+};
+
 /// An analysis that, by going backwards along the dataflow graph, annotates
 /// each value with all the memory resources it (or anything derived from it)
 /// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
   void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
                          ArrayRef<const WrittenTo *> results) override;
 
-  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+  void setToExitState(WrittenTo *lattice) override {
+    lattice->getValue().clear();
+  }
 
 private:
   bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
   if (auto store = dyn_cast<memref::StoreOp>(op)) {
     SetVector<StringAttr> newWrites;
     newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
-    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    propagateIfChanged(operands[0],
+                       operands[0]->getValue().addWrites(newWrites));
     return;
   } // By default, every result of an op depends on every operand.
   for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "brancharg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "callarg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                              call.getOperation()->getName().getStringRef());
     }
     newWrites.insert(name);
-    propagateIfChanged(lattice, lattice->addWrites(newWrites));
+    propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
   }
 }
 


        


More information about the Mlir-commits mailing list