[Mlir-commits] [mlir] d1cff36 - [MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes (#82620)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 17 04:41:33 PDT 2024
Author: Andi Drebes
Date: 2024-05-17T13:41:30+02:00
New Revision: d1cff36e5e37cd552ec049335feb1dd8f94517ea
URL: https://github.com/llvm/llvm-project/commit/d1cff36e5e37cd552ec049335feb1dd8f94517ea
DIFF: https://github.com/llvm/llvm-project/commit/d1cff36e5e37cd552ec049335feb1dd8f94517ea.diff
LOG: [MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes (#82620)
The class `Lattice` should automatically delegate invocations of the
meet operator to the meet operation of the associated lattice value
class if that class provides a static function called `meet`. This
delegation fails for two reasons:
1. `Lattice::has_meet` checks whether the lattice value class has a
member function `meet` that takes no arguments, although it should
check for a static member function `meet`.
2. The function template `Lattice::meet<VT>()`, which implements the
default (no-op) meet operation directly in the lattice, is always
enabled and therefore takes precedence over the delegating function
template `Lattice::meet<VT, std::integral_constant<bool, true>>()`.
This change fixes the automatic delegation of the meet operation of a
lattice to the lattice value class in the presence of a static `meet`
function in two ways. First, it enables exactly one of the delegating
and non-delegating function templates, depending on whether the
lattice value type provides `meet`. Second, it changes
`Lattice::has_meet` so that it checks for a static `meet` member
function in the lattice value type.
The test in `TestSparseBackwardDataFlowAnalysis.cpp` is changed so that
the `meet` function is no longer provided directly by the `WrittenTo`
lattice but by the `Lattice` base class. This exercises delegation to a
lattice value class.
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec2..7aadd5409cc69 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
/// analysis, lattices will only have a `join`, no `meet`, but we want to use
/// the same `Lattice` class for both directions.
template <typename T, typename... Args>
- using has_meet = decltype(std::declval<T>().meet());
+ using has_meet = decltype(&T::meet);
template <typename T>
using lattice_has_meet = llvm::is_detected<has_meet, T>;
/// Meet (intersect) the information contained in the 'rhs' value with this
/// lattice. Returns if the state of the current lattice changed. If the
/// lattice elements don't have a `meet` method, this is a no-op (see below.)
- template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
+ template <typename VT,
+ std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
ChangeResult meet(const VT &rhs) {
ValueT newValue = ValueT::meet(value, rhs);
assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
return ChangeResult::Change;
}
- template <typename VT>
+ template <typename VT,
+ std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
ChangeResult meet(const VT &rhs) {
return ChangeResult::NoChange;
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5e..3029738046644 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
namespace {
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
- using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing the a set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+ bool operator==(const WrittenToLatticeValue &other) {
+ return this->writes == other.writes;
+ }
- void print(raw_ostream &os) const override {
- os << "[";
- llvm::interleave(
- writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
- os << "]";
+ static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+ const WrittenToLatticeValue &rhs) {
+ WrittenToLatticeValue res = lhs;
+ (void)res.addWrites(rhs.writes);
+
+ return res;
}
+
+ static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+ const WrittenToLatticeValue &rhs) {
+ // Should not be triggered by this test, but required by `Lattice<T>`
+ llvm_unreachable("Join should not be triggered by this test");
+ }
+
ChangeResult addWrites(const SetVector<StringAttr> &writes) {
int sizeBefore = this->writes.size();
this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
return sizeBefore == sizeAfter ? ChangeResult::NoChange
: ChangeResult::Change;
}
- ChangeResult meet(const AbstractSparseLattice &other) override {
- const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
- return addWrites(rhs->writes);
+
+ void print(raw_ostream &os) const {
+ os << "[";
+ llvm::interleave(
+ writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+ os << "]";
}
+ void clear() { writes.clear(); }
+
SetVector<StringAttr> writes;
};
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice<WrittenToLatticeValue> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+ using Lattice::Lattice;
+};
+
/// An analysis that, by going backwards along the dataflow graph, annotates
/// each value with all the memory resources it (or anything derived from it)
/// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
ArrayRef<const WrittenTo *> results) override;
- void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+ void setToExitState(WrittenTo *lattice) override {
+ lattice->getValue().clear();
+ }
private:
bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
if (auto store = dyn_cast<memref::StoreOp>(op)) {
SetVector<StringAttr> newWrites;
newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
- propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+ propagateIfChanged(operands[0],
+ operands[0]->getValue().addWrites(newWrites));
return;
} // By default, every result of an op depends on every operand.
for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
newWrites.insert(
StringAttr::get(operand.getOwner()->getContext(),
"brancharg" + Twine(operand.getOperandNumber())));
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
newWrites.insert(
StringAttr::get(operand.getOwner()->getContext(),
"callarg" + Twine(operand.getOperandNumber())));
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
call.getOperation()->getName().getStringRef());
}
newWrites.insert(name);
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
}
More information about the Mlir-commits
mailing list