[Mlir-commits] [mlir] 0118a80 - [ADT] Add Compare template param to EquivalenceClasses

Matthias Springer llvmlistbot at llvm.org
Mon Nov 1 01:38:51 PDT 2021


Author: Matthias Springer
Date: 2021-11-01T17:16:03+09:00
New Revision: 0118a8044f8bda1a2e1b3add5e244ef4ce714982

URL: https://github.com/llvm/llvm-project/commit/0118a8044f8bda1a2e1b3add5e244ef4ce714982
DIFF: https://github.com/llvm/llvm-project/commit/0118a8044f8bda1a2e1b3add5e244ef4ce714982.diff

LOG: [ADT] Add Compare template param to EquivalenceClasses

This makes the class usable with types that do not provide their own operator<.

Update MLIR Linalg ComprehensiveBufferize to take advantage of the new template param.

Differential Revision: https://reviews.llvm.org/D112052

Added: 
    

Modified: 
    llvm/include/llvm/ADT/EquivalenceClasses.h
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index 273b00f99d5d8..de6bb3bca7e33 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -30,7 +30,8 @@ namespace llvm {
 ///
 /// This implementation is an efficient implementation that only stores one copy
 /// of the element being indexed per entry in the set, and allows any arbitrary
-/// type to be indexed (as long as it can be ordered with operator<).
+/// type to be indexed (as long as it can be ordered with operator< or a
+/// comparator is provided).
 ///
 /// Here is a simple example using integers:
 ///
@@ -54,7 +55,7 @@ namespace llvm {
 ///   4
 ///   5 1 2
 ///
-template <class ElemTy>
+template <class ElemTy, class Compare = std::less<ElemTy>>
 class EquivalenceClasses {
   /// ECValue - The EquivalenceClasses data structure is just a set of these.
   /// Each of these represents a relation for a value.  First it stores the
@@ -101,22 +102,40 @@ class EquivalenceClasses {
       assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
     }
 
-    bool operator<(const ECValue &UFN) const { return Data < UFN.Data; }
-
     bool isLeader() const { return (intptr_t)Next & 1; }
     const ElemTy &getData() const { return Data; }
 
     const ECValue *getNext() const {
       return (ECValue*)((intptr_t)Next & ~(intptr_t)1);
     }
+  };
+
+  /// Adapts Compare to ECValue for the std::set; also allows lookup by ElemTy.
+  struct ECValueComparator {
+    using is_transparent = void;
+
+    ECValueComparator() : compare(Compare()) {}
+
+    bool operator()(const ECValue &lhs, const ECValue &rhs) const {
+      return compare(lhs.Data, rhs.Data);
+    }
+
+    template <typename T>
+    bool operator()(const T &lhs, const ECValue &rhs) const {
+      return compare(lhs, rhs.Data);
+    }
+
+    template <typename T>
+    bool operator()(const ECValue &lhs, const T &rhs) const {
+      return compare(lhs.Data, rhs);
+    }
 
-    template<typename T>
-    bool operator<(const T &Val) const { return Data < Val; }
+    const Compare compare;
   };
 
   /// TheMapping - This implicitly provides a mapping from ElemTy values to the
   /// ECValues, it just keeps the key as part of the value.
-  std::set<ECValue> TheMapping;
+  std::set<ECValue, ECValueComparator> TheMapping;
 
 public:
   EquivalenceClasses() = default;

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 5baec15446df6..f793ae4eec9c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -122,23 +122,17 @@ class BufferizationAliasInfo {
   void dumpEquivalences() const;
 
 private:
-  /// llvm::EquivalenceClasses wants comparable elements because it uses
-  /// std::set as the underlying impl.
-  /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
-  /// This is a poor man's comparison but it's not like UnionFind needs ordering
-  /// anyway ..
-  struct ValueWrapper {
-    ValueWrapper(Value val) : v(val) {}
-    operator Value() const { return v; }
-    bool operator<(const ValueWrapper &wrap) const {
-      return v.getImpl() < wrap.v.getImpl();
+  /// llvm::EquivalenceClasses wants comparable elements. This comparator
+  /// compares the pointers to the Values' underlying impls. This is a poor
+  /// man's comparison but it's not like UnionFind needs ordering anyway.
+  struct ValueComparator {
+    bool operator()(const Value &lhs, const Value &rhs) const {
+      return lhs.getImpl() < rhs.getImpl();
     }
-    bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
-    Value v;
   };
 
   using EquivalenceClassRangeType = llvm::iterator_range<
-      llvm::EquivalenceClasses<ValueWrapper>::member_iterator>;
+      llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
@@ -164,10 +158,10 @@ class BufferizationAliasInfo {
   /// Auxiliary structure to store all the values a given value aliases with.
   /// These are the conservative cases that can further decompose into
   /// "equivalent" buffer relationships.
-  llvm::EquivalenceClasses<ValueWrapper> aliasInfo;
+  llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
 
   /// Auxiliary structure to store all the equivalent buffer classes.
-  llvm::EquivalenceClasses<ValueWrapper> equivalentInfo;
+  llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
 };
 
 /// Analyze the `ops` to determine which OpResults are inplaceable.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0373867ee1b6e..691d56f7ab3e3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1213,11 +1213,11 @@ bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
        ++mit) {
     auto extractSliceOp =
-        dyn_cast_or_null<ExtractSliceOp>(mit->v.getDefiningOp());
+        dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
     if (extractSliceOp &&
         areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
         getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
-      LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n');
+      LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
       return true;
     }
   }
@@ -1231,7 +1231,7 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
   auto leaderIt = equivalentInfo.findLeader(v);
   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
        ++mit) {
-    fun(mit->v);
+    fun(*mit);
   }
 }
 


        


More information about the Mlir-commits mailing list