[Mlir-commits] [mlir] 0118a80 - [ADT] Add Compare template param to EquivalenceClasses
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 1 01:38:51 PDT 2021
Author: Matthias Springer
Date: 2021-11-01T17:16:03+09:00
New Revision: 0118a8044f8bda1a2e1b3add5e244ef4ce714982
URL: https://github.com/llvm/llvm-project/commit/0118a8044f8bda1a2e1b3add5e244ef4ce714982
DIFF: https://github.com/llvm/llvm-project/commit/0118a8044f8bda1a2e1b3add5e244ef4ce714982.diff
LOG: [ADT] Add Compare template param to EquivalenceClasses
This makes the class usable with types that do not provide their own operator<.
Update MLIR Linalg ComprehensiveBufferize to take advantage of the new template param.
Differential Revision: https://reviews.llvm.org/D112052
Added:
Modified:
llvm/include/llvm/ADT/EquivalenceClasses.h
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h
index 273b00f99d5d8..de6bb3bca7e33 100644
--- a/llvm/include/llvm/ADT/EquivalenceClasses.h
+++ b/llvm/include/llvm/ADT/EquivalenceClasses.h
@@ -30,7 +30,8 @@ namespace llvm {
///
/// This implementation is an efficient implementation that only stores one copy
/// of the element being indexed per entry in the set, and allows any arbitrary
-/// type to be indexed (as long as it can be ordered with operator<).
+/// type to be indexed (as long as it can be ordered with operator< or a
+/// comparator is provided).
///
/// Here is a simple example using integers:
///
@@ -54,7 +55,7 @@ namespace llvm {
/// 4
/// 5 1 2
///
-template <class ElemTy>
+template <class ElemTy, class Compare = std::less<ElemTy>>
class EquivalenceClasses {
/// ECValue - The EquivalenceClasses data structure is just a set of these.
/// Each of these represents a relation for a value. First it stores the
@@ -101,22 +102,40 @@ class EquivalenceClasses {
assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
}
- bool operator<(const ECValue &UFN) const { return Data < UFN.Data; }
-
bool isLeader() const { return (intptr_t)Next & 1; }
const ElemTy &getData() const { return Data; }
const ECValue *getNext() const {
return (ECValue*)((intptr_t)Next & ~(intptr_t)1);
}
+ };
+
+ /// Adapts the user-provided `Compare` so the underlying std::set can order ECValues by their data.
+ struct ECValueComparator {
+ using is_transparent = void;
+
+ ECValueComparator() : compare(Compare()) {}
+
+ bool operator()(const ECValue &lhs, const ECValue &rhs) const {
+ return compare(lhs.Data, rhs.Data);
+ }
+
+ template <typename T>
+ bool operator()(const T &lhs, const ECValue &rhs) const {
+ return compare(lhs, rhs.Data);
+ }
+
+ template <typename T>
+ bool operator()(const ECValue &lhs, const T &rhs) const {
+ return compare(lhs.Data, rhs);
+ }
- template<typename T>
- bool operator<(const T &Val) const { return Data < Val; }
+ const Compare compare;
};
/// TheMapping - This implicitly provides a mapping from ElemTy values to the
/// ECValues, it just keeps the key as part of the value.
- std::set<ECValue> TheMapping;
+ std::set<ECValue, ECValueComparator> TheMapping;
public:
EquivalenceClasses() = default;
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 5baec15446df6..f793ae4eec9c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -122,23 +122,17 @@ class BufferizationAliasInfo {
void dumpEquivalences() const;
private:
- /// llvm::EquivalenceClasses wants comparable elements because it uses
- /// std::set as the underlying impl.
- /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
- /// This is a poor man's comparison but it's not like UnionFind needs ordering
- /// anyway ..
- struct ValueWrapper {
- ValueWrapper(Value val) : v(val) {}
- operator Value() const { return v; }
- bool operator<(const ValueWrapper &wrap) const {
- return v.getImpl() < wrap.v.getImpl();
+ /// llvm::EquivalenceClasses wants comparable elements. This comparator
+ /// orders Values by their underlying impl pointer (`Value::getImpl()`). The
+ /// order is arbitrary, but union-find only needs a consistent one.
+ struct ValueComparator {
+ bool operator()(const Value &lhs, const Value &rhs) const {
+ return lhs.getImpl() < rhs.getImpl();
}
- bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
- Value v;
};
using EquivalenceClassRangeType = llvm::iterator_range<
- llvm::EquivalenceClasses<ValueWrapper>::member_iterator>;
+ llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
@@ -164,10 +158,10 @@ class BufferizationAliasInfo {
/// Auxiliary structure to store all the values a given value aliases with.
/// These are the conservative cases that can further decompose into
/// "equivalent" buffer relationships.
- llvm::EquivalenceClasses<ValueWrapper> aliasInfo;
+ llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
/// Auxiliary structure to store all the equivalent buffer classes.
- llvm::EquivalenceClasses<ValueWrapper> equivalentInfo;
+ llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};
/// Analyze the `ops` to determine which OpResults are inplaceable.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0373867ee1b6e..691d56f7ab3e3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1213,11 +1213,11 @@ bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
auto extractSliceOp =
- dyn_cast_or_null<ExtractSliceOp>(mit->v.getDefiningOp());
+ dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
if (extractSliceOp &&
areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
- LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n');
+ LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
return true;
}
}
@@ -1231,7 +1231,7 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
auto leaderIt = equivalentInfo.findLeader(v);
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
- fun(mit->v);
+ fun(*mit);
}
}
More information about the Mlir-commits mailing list