[Mlir-commits] [mlir] [mlir][OperationEquivalence] Add an extra callback to hook operation equivalence check (PR #73455)
Ivan Butygin
llvmlistbot at llvm.org
Sun Nov 26 10:59:04 PST 2023
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/73455
Rebase of https://reviews.llvm.org/D155577 by @uenoku
The only changes are resolving a merge conflict around the `simpleOpEquivalence` function and an additional update to `CFGToSCF.cpp`.
Original description:
This patch adds an extra callback, `checkOpStructureEquivalent`, to the method `OperationEquivalence::isEquivalentTo` so that users can customize equivalence with respect to the structural properties of operations.
I'm not sure that "structural properties" is the best term to describe the operation information used here. I think just "properties" would be better, but I added "structural" to clearly distinguish it from operation properties.
This patch is preparation for providing a dialect interface for CSE: https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831
>From 6ab1bfdb1971653985c83c7d8efa9f36290bce3c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 26 Nov 2023 19:46:56 +0100
Subject: [PATCH] [mlir][OperationEquivalence] Add an extra callback to hook
operation equivalence check
Rebase of https://reviews.llvm.org/D155577 by @uenoku
The only changes are resolving a merge conflict around the `simpleOpEquivalence` function and an additional update to `CFGToSCF.cpp`.
Original description:
This patch adds an extra callback, `checkOpStructureEquivalent`, to the method `OperationEquivalence::isEquivalentTo` so that users can customize equivalence with respect to the structural properties of operations.
I'm not sure that "structural properties" is the best term to describe the operation information used here. I think just "properties" would be better, but I added "structural" to clearly distinguish it from operation properties.
This patch is preparation for providing a dialect interface for CSE: https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831
---
mlir/include/mlir/IR/OperationSupport.h | 48 +++++++++++---
mlir/lib/IR/OperationSupport.cpp | 76 +++++++++++++++--------
mlir/lib/Transforms/Utils/CFGToSCF.cpp | 1 +
mlir/lib/Transforms/Utils/RegionUtils.cpp | 3 +-
4 files changed, 91 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b4..d6c5d50efb7adc1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1236,16 +1236,31 @@ struct OperationEquivalence {
};
/// Compute a hash for the given operation.
- /// The `hashOperands` and `hashResults` callbacks are expected to return a
- /// unique hash_code for a given Value.
+ /// The `hashOp` is a callback to compute a hash for structural properties of
+ /// `op` such as op name, result types and attributes. The `hashOperands` and
+ /// `hashResults` callbacks are expected to return a unique hash_code for a
+ /// given Value.
static llvm::hash_code computeHash(
- Operation *op,
+ Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
function_ref<llvm::hash_code(Value)> hashOperands =
[](Value v) { return hash_value(v); },
function_ref<llvm::hash_code(Value)> hashResults =
[](Value v) { return hash_value(v); },
Flags flags = Flags::None);
+ static llvm::hash_code computeHash(
+ Operation *op,
+ function_ref<llvm::hash_code(Value)> hashOperands =
+ [](Value v) { return hash_value(v); },
+ function_ref<llvm::hash_code(Value)> hashResults =
+ [](Value v) { return hash_value(v); },
+ Flags flags = Flags::None) {
+ return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+ }
+
+ /// Helper that can be used with `computeHash` above to combine hashes for
+ /// basic structural properties.
+ static llvm::hash_code simpleOpHash(Operation *op);
/// Helper that can be used with `computeHash` above to ignore operation
/// operands/result mapping.
static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1255,38 +1270,51 @@ struct OperationEquivalence {
/// Compare two operations (including their regions) and return if they are
/// equivalent.
- ///
- /// * `checkEquivalent` is a callback to check if two values are equivalent.
+ /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+ /// two operations are equivalent.
+ /// * `checkValueEquivalent` is a callback to check if two values are
+ /// equivalent.
/// For two operations to be equivalent, their operands must be the same SSA
/// value or this callback must return `success`.
/// * `markEquivalent` is a callback to inform the caller that the analysis
/// determined that two values are equivalent.
///
/// Note: Additional information regarding value equivalence can be injected
- /// into the analysis via `checkEquivalent`. Typically, callers may want
- /// values that were determined to be equivalent as per `markEquivalent` to be
- /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+ /// into the analysis via `checkOpStructureEquivalent` and
+ /// `checkValueEquivalent`. Typically, callers may want values that were
+ /// determined to be equivalent as per `markEquivalent` to be reflected in
+ /// `checkValueEquivalent`, unless `exactValueMatch` or a different
/// equivalence relationship is desired.
static bool
isEquivalentTo(Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent = nullptr,
Flags flags = Flags::None);
/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+ static bool
+ isEquivalentTo(Operation *lhs, Operation *rhs,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ Flags flags);
/// Compare two regions (including their subregions) and return if they are
/// equivalent. See also `isEquivalentTo` for details.
static bool isRegionEquivalentTo(
Region *lhs, Region *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent,
OperationEquivalence::Flags flags);
/// Compare two regions and return if they are equivalent.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
OperationEquivalence::Flags flags);
+ static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
/// Helper that can be used with `isEquivalentTo` above to consider ops
/// equivalent even if their operands are not equivalent.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d8..1d2189e530cdf60 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -668,15 +668,11 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
//===----------------------------------------------------------------------===//
llvm::hash_code OperationEquivalence::computeHash(
- Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+ Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+ function_ref<llvm::hash_code(Value)> hashOperands,
function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
- // Hash operations based upon their:
- // - Operation Name
- // - Attributes
- // - Result Types
- llvm::hash_code hash =
- llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
- op->getResultTypes(), op->hashProperties());
+ // Hash operations based upon their structural properties using `hashOp`.
+ llvm::hash_code hash = hashOp(op);
// - Location if required
if (!(flags & Flags::IgnoreLocations))
@@ -694,7 +690,9 @@ llvm::hash_code OperationEquivalence::computeHash(
/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
Region *lhs, Region *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent,
OperationEquivalence::Flags flags) {
DenseMap<Block *, Block *> blocksMap;
@@ -724,8 +722,9 @@ llvm::hash_code OperationEquivalence::computeHash(
auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
// Check for op equality (recursively).
- if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
- markEquivalent, flags))
+ if (!OperationEquivalence::isEquivalentTo(
+ &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+ markEquivalent, flags))
return false;
// Check successor mapping.
for (auto successorsPair :
@@ -766,7 +765,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
OperationEquivalence::Flags flags) {
ValueEquivalenceCache cache;
return isRegionEquivalentTo(
- lhs, rhs,
+ lhs, rhs, simpleOpEquivalence,
[&](Value lhsValue, Value rhsValue) -> LogicalResult {
return cache.checkEquivalent(lhsValue, rhsValue);
},
@@ -776,24 +775,39 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
flags);
}
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+ return llvm::hash_combine(op->getName(), op->getResultTypes(),
+ op->hashProperties(),
+ op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+ return LogicalResult::success(
+ lhs->getName() == rhs->getName() &&
+ lhs->getDiscardableAttrDictionary() ==
+ rhs->getDiscardableAttrDictionary() &&
+ lhs->getNumRegions() == rhs->getNumRegions() &&
+ lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+ lhs->getNumOperands() == rhs->getNumOperands() &&
+ lhs->getNumResults() == rhs->getNumResults() &&
+ lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+ rhs->getPropertiesStorage()));
+}
+
/*static*/ bool OperationEquivalence::isEquivalentTo(
Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent, Flags flags) {
if (lhs == rhs)
return true;
- // 1. Compare the operation properties.
- if (lhs->getName() != rhs->getName() ||
- lhs->getDiscardableAttrDictionary() !=
- rhs->getDiscardableAttrDictionary() ||
- lhs->getNumRegions() != rhs->getNumRegions() ||
- lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
- lhs->getNumOperands() != rhs->getNumOperands() ||
- lhs->getNumResults() != rhs->getNumResults() ||
- !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
- rhs->getPropertiesStorage()))
+ // 1. Compare the operation structural properties.
+ if (failed(checkOpStructureEquivalent(lhs, rhs)))
return false;
+
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
@@ -805,7 +819,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
continue;
if (curArg.getType() != otherArg.getType())
return false;
- if (failed(checkEquivalent(curArg, otherArg)))
+ if (failed(checkValueEquivalent(curArg, otherArg)))
return false;
}
@@ -822,7 +836,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
// 4. Compare regions.
for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
if (!isRegionEquivalentTo(&std::get<0>(regionPair),
- &std::get<1>(regionPair), checkEquivalent,
+ &std::get<1>(regionPair),
+ checkOpStructureEquivalent, checkValueEquivalent,
markEquivalent, flags))
return false;
@@ -832,9 +847,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
Operation *rhs,
Flags flags) {
+ return OperationEquivalence::isEquivalentTo(lhs, rhs, simpleOpEquivalence,
+ flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+ Operation *lhs, Operation *rhs,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ Flags flags) {
ValueEquivalenceCache cache;
return OperationEquivalence::isEquivalentTo(
- lhs, rhs,
+ lhs, rhs, checkOpStructureEquivalent,
[&](Value lhsValue, Value rhsValue) -> LogicalResult {
return cache.checkEquivalent(lhsValue, rhsValue);
},
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index f2998b4047e201e..9982015ba3de779 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -411,6 +411,7 @@ struct ReturnLikeOpEquivalence : public llvm::DenseMapInfo<Operation *> {
return false;
return OperationEquivalence::isEquivalentTo(
const_cast<Operation *>(lhs), const_cast<Operation *>(rhs),
+ OperationEquivalence::simpleOpEquivalence,
OperationEquivalence::ignoreValueEquivalence, nullptr,
OperationEquivalence::IgnoreLocations);
}
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 1f2344677e6515c..e4e14baaeab7593 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
// Check that the operations are equivalent.
if (!OperationEquivalence::isEquivalentTo(
- &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+ &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+ OperationEquivalence::ignoreValueEquivalence,
/*markEquivalent=*/nullptr,
OperationEquivalence::Flags::IgnoreLocations))
return failure();
More information about the Mlir-commits
mailing list