[Mlir-commits] [mlir] [mlir][CSE] Introduce DialectInterface for CSE (PR #73520)
Ivan Butygin
llvmlistbot at llvm.org
Mon Nov 27 06:21:52 PST 2023
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/73520
Based on https://reviews.llvm.org/D154512 by @uenoku
Depends on https://github.com/llvm/llvm-project/pull/73455; please review the second commit.
Changed the interface return types to `std::optional` so that the interface can override equivalence for some ops and rely on the default implementation for others.
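Below is a minimal sketch of this `std::optional` fallback pattern in plain C++, not the MLIR API. `ToyOp`, `dialectIsEqual`, `defaultIsEqual`, and `isEqual` are hypothetical names used only for illustration. The hook answers only for the ops it cares about and returns `std::nullopt` for everything else, and the caller then falls back to the default comparison.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an operation: a name plus one attribute value.
struct ToyOp {
  std::string name;
  std::string tag;
};

// Hypothetical dialect hook: it only has an opinion about "test.custom" ops
// (which it treats as equal regardless of `tag`) and defers on all others.
std::optional<bool> dialectIsEqual(const ToyOp &lhs, const ToyOp &rhs) {
  if (lhs.name != "test.custom" || rhs.name != "test.custom")
    return std::nullopt; // no opinion: use the default implementation
  return true;
}

// Default equivalence: names and tags must both match.
bool defaultIsEqual(const ToyOp &lhs, const ToyOp &rhs) {
  return lhs.name == rhs.name && lhs.tag == rhs.tag;
}

// Caller side: prefer the hook's answer if it has one, else use the default.
// Note that `if (optional)` tests has_value(), not the contained bool.
bool isEqual(const ToyOp &lhs, const ToyOp &rhs) {
  if (std::optional<bool> answer = dialectIsEqual(lhs, rhs))
    return *answer;
  return defaultIsEqual(lhs, rhs);
}
```

Returning `std::nullopt` instead of a plain `bool` lets a dialect customize a handful of ops without reimplementing the default equivalence for the rest.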
Original description:
This patch implements CSE dialect interfaces proposed in https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831/9.
DialectCSEInterface has three methods. `getHashValue` and `isEqual` are passed to the `DenseMapInfo` used by CSE, giving users a way to customize CSE behavior. `mergeOperations` is a callback invoked when an operation is eliminated, so that information from the eliminated operation can be propagated to the one that is kept.
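To show how a hash/equality pair drives deduplication, here is a plain C++ sketch that uses `std::unordered_map` as a stand-in for the `DenseMapInfo`-based table in CSE. `ToyOp`, `ToyOpHash`, `ToyOpEqual`, and `deduplicate` are hypothetical. The `nonEssential` field plays the role of the `test.non_essential` attribute from the test dialect and is deliberately excluded from both the hash and the equality check.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical operation model: a name, one attribute that matters for CSE,
// and one "non-essential" attribute that CSE should ignore.
struct ToyOp {
  std::string name;
  std::string essential;
  std::string nonEssential;
};

// Analogue of getHashValue: hash only the fields that matter for CSE.
struct ToyOpHash {
  std::size_t operator()(const ToyOp *op) const {
    std::size_t h = std::hash<std::string>{}(op->name);
    return h ^ (std::hash<std::string>{}(op->essential) << 1);
  }
};

// Analogue of isEqual: ops it accepts must also hash identically.
struct ToyOpEqual {
  bool operator()(const ToyOp *lhs, const ToyOp *rhs) const {
    return lhs->name == rhs->name && lhs->essential == rhs->essential;
  }
};

// For each op, returns the index of the first equivalent op seen so far
// (its own index if it is the first of its kind), like CSE's known values.
std::vector<std::size_t> deduplicate(const std::vector<ToyOp> &ops) {
  std::unordered_map<const ToyOp *, std::size_t, ToyOpHash, ToyOpEqual> known;
  std::vector<std::size_t> leader;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    // try_emplace keeps the existing entry if an equivalent op is present.
    auto result = known.try_emplace(&ops[i], i);
    leader.push_back(result.first->second);
  }
  return leader;
}
```

The key constraint matches `DenseMapInfo`: any two ops that `isEqual` accepts must produce the same hash, otherwise they may land in different buckets and never be compared.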
From 9ff0f41cddef79d2cb11ea84496bd9e5c2e3c5e7 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 26 Nov 2023 19:46:56 +0100
Subject: [PATCH 1/2] [mlir][OperationEquivalence] Add an extra callback to hook operation equivalence check
Rebase of https://reviews.llvm.org/D155577 by @uenoku
The only changes are a merge conflict resolution around the `simpleOpEquivalence` function and the `CFGToSCF.cpp` change.
Original description:
This patch adds an extra callback, `checkOpStructureEquivalent`, to `OperationEquivalence::isEquivalentTo` so that users can customize equivalence with respect to the structural properties of operations.
I'm not sure that "structural properties" is the best term to describe the operation information used here. I think just "properties" would be better, but I added "structural" to clearly distinguish it from operation properties.
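Here is a simplified, MLIR-free sketch of the split this patch introduces: one callback decides whether two ops match structurally, and another decides whether two operand values match. `ToyOp`, `simpleStructure`, and this `isEquivalentTo` are hypothetical stand-ins (operands are modeled as integer value ids), not the real `OperationEquivalence` signatures. Regions, successors, and locations are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical op model: name, attributes, and operands as integer value ids.
struct ToyOp {
  std::string name;
  std::vector<std::string> attrs;
  std::vector<int> operands;
};

using StructureCheck = std::function<bool(const ToyOp &, const ToyOp &)>;
using ValueCheck = std::function<bool(int, int)>;

// Default structural check, loosely modeled on simpleOpEquivalence.
bool simpleStructure(const ToyOp &lhs, const ToyOp &rhs) {
  return lhs.name == rhs.name && lhs.attrs == rhs.attrs &&
         lhs.operands.size() == rhs.operands.size();
}

// Structure is delegated to one callback, operand values to another.
bool isEquivalentTo(const ToyOp &lhs, const ToyOp &rhs,
                    const StructureCheck &checkStructure,
                    const ValueCheck &checkValue) {
  if (!checkStructure(lhs, rhs))
    return false;
  // Guard the loop below even if a custom structure check skips this.
  if (lhs.operands.size() != rhs.operands.size())
    return false;
  for (std::size_t i = 0; i < lhs.operands.size(); ++i) {
    int l = lhs.operands[i], r = rhs.operands[i];
    // Identical values always match; otherwise ask the value callback.
    if (l != r && !checkValue(l, r))
      return false;
  }
  return true;
}
```

Passing a structural check that ignores attributes makes two ops with different attributes compare equal, while operand matching is still governed by the value callback.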
This patch is in preparation for providing a dialect interface for CSE: https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831
---
mlir/include/mlir/IR/OperationSupport.h | 48 +++++++++++---
mlir/lib/IR/OperationSupport.cpp | 76 +++++++++++++++--------
mlir/lib/Transforms/Utils/CFGToSCF.cpp | 1 +
mlir/lib/Transforms/Utils/RegionUtils.cpp | 3 +-
4 files changed, 91 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b4..d6c5d50efb7adc1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1236,16 +1236,31 @@ struct OperationEquivalence {
};
/// Compute a hash for the given operation.
- /// The `hashOperands` and `hashResults` callbacks are expected to return a
- /// unique hash_code for a given Value.
+ /// The `hashOp` is a callback to compute a hash for structural properties of
+ /// `op` such as op name, result types and attributes. The `hashOperands` and
+ /// `hashResults` callbacks are expected to return a unique hash_code for a
+ /// given Value.
static llvm::hash_code computeHash(
- Operation *op,
+ Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
function_ref<llvm::hash_code(Value)> hashOperands =
[](Value v) { return hash_value(v); },
function_ref<llvm::hash_code(Value)> hashResults =
[](Value v) { return hash_value(v); },
Flags flags = Flags::None);
+ static llvm::hash_code computeHash(
+ Operation *op,
+ function_ref<llvm::hash_code(Value)> hashOperands =
+ [](Value v) { return hash_value(v); },
+ function_ref<llvm::hash_code(Value)> hashResults =
+ [](Value v) { return hash_value(v); },
+ Flags flags = Flags::None) {
+ return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+ }
+
+ /// Helper that can be used with `computeHash` above to combine hashes for
+ /// basic structural properties.
+ static llvm::hash_code simpleOpHash(Operation *op);
/// Helper that can be used with `computeHash` above to ignore operation
/// operands/result mapping.
static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1255,38 +1270,51 @@ struct OperationEquivalence {
/// Compare two operations (including their regions) and return if they are
/// equivalent.
- ///
- /// * `checkEquivalent` is a callback to check if two values are equivalent.
+ /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+ /// two operations are equivalent.
+ /// * `checkValueEquivalent` is a callback to check if two values are
+ /// equivalent.
/// For two operations to be equivalent, their operands must be the same SSA
/// value or this callback must return `success`.
/// * `markEquivalent` is a callback to inform the caller that the analysis
/// determined that two values are equivalent.
///
/// Note: Additional information regarding value equivalence can be injected
- /// into the analysis via `checkEquivalent`. Typically, callers may want
- /// values that were determined to be equivalent as per `markEquivalent` to be
- /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+ /// into the analysis via `checkOpStructureEquivalent` and
+ /// `checkValueEquivalent`. Typically, callers may want values that were
+ /// determined to be equivalent as per `markEquivalent` to be reflected in
+ /// `checkValueEquivalent`, unless `exactValueMatch` or a different
/// equivalence relationship is desired.
static bool
isEquivalentTo(Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent = nullptr,
Flags flags = Flags::None);
/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+ static bool
+ isEquivalentTo(Operation *lhs, Operation *rhs,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ Flags flags);
/// Compare two regions (including their subregions) and return if they are
/// equivalent. See also `isEquivalentTo` for details.
static bool isRegionEquivalentTo(
Region *lhs, Region *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent,
OperationEquivalence::Flags flags);
/// Compare two regions and return if they are equivalent.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
OperationEquivalence::Flags flags);
+ static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
/// Helper that can be used with `isEquivalentTo` above to consider ops
/// equivalent even if their operands are not equivalent.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d8..1d2189e530cdf60 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -668,15 +668,11 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
//===----------------------------------------------------------------------===//
llvm::hash_code OperationEquivalence::computeHash(
- Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+ Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+ function_ref<llvm::hash_code(Value)> hashOperands,
function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
- // Hash operations based upon their:
- // - Operation Name
- // - Attributes
- // - Result Types
- llvm::hash_code hash =
- llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
- op->getResultTypes(), op->hashProperties());
+ // Hash operations based upon their structural properties using `hashOp`.
+ llvm::hash_code hash = hashOp(op);
// - Location if required
if (!(flags & Flags::IgnoreLocations))
@@ -694,7 +690,9 @@ llvm::hash_code OperationEquivalence::computeHash(
/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
Region *lhs, Region *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent,
OperationEquivalence::Flags flags) {
DenseMap<Block *, Block *> blocksMap;
@@ -724,8 +722,9 @@ llvm::hash_code OperationEquivalence::computeHash(
auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
// Check for op equality (recursively).
- if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
- markEquivalent, flags))
+ if (!OperationEquivalence::isEquivalentTo(
+ &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+ markEquivalent, flags))
return false;
// Check successor mapping.
for (auto successorsPair :
@@ -766,7 +765,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
OperationEquivalence::Flags flags) {
ValueEquivalenceCache cache;
return isRegionEquivalentTo(
- lhs, rhs,
+ lhs, rhs, simpleOpEquivalence,
[&](Value lhsValue, Value rhsValue) -> LogicalResult {
return cache.checkEquivalent(lhsValue, rhsValue);
},
@@ -776,24 +775,39 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
flags);
}
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+ return llvm::hash_combine(op->getName(), op->getResultTypes(),
+ op->hashProperties(),
+ op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+ return LogicalResult::success(
+ lhs->getName() == rhs->getName() &&
+ lhs->getDiscardableAttrDictionary() ==
+ rhs->getDiscardableAttrDictionary() &&
+ lhs->getNumRegions() == rhs->getNumRegions() &&
+ lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+ lhs->getNumOperands() == rhs->getNumOperands() &&
+ lhs->getNumResults() == rhs->getNumResults() &&
+ lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+ rhs->getPropertiesStorage()));
+}
+
/*static*/ bool OperationEquivalence::isEquivalentTo(
Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
function_ref<void(Value, Value)> markEquivalent, Flags flags) {
if (lhs == rhs)
return true;
- // 1. Compare the operation properties.
- if (lhs->getName() != rhs->getName() ||
- lhs->getDiscardableAttrDictionary() !=
- rhs->getDiscardableAttrDictionary() ||
- lhs->getNumRegions() != rhs->getNumRegions() ||
- lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
- lhs->getNumOperands() != rhs->getNumOperands() ||
- lhs->getNumResults() != rhs->getNumResults() ||
- !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
- rhs->getPropertiesStorage()))
+ // 1. Compare the operation structural properties.
+ if (failed(checkOpStructureEquivalent(lhs, rhs)))
return false;
+
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
@@ -805,7 +819,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
continue;
if (curArg.getType() != otherArg.getType())
return false;
- if (failed(checkEquivalent(curArg, otherArg)))
+ if (failed(checkValueEquivalent(curArg, otherArg)))
return false;
}
@@ -822,7 +836,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
// 4. Compare regions.
for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
if (!isRegionEquivalentTo(&std::get<0>(regionPair),
- &std::get<1>(regionPair), checkEquivalent,
+ &std::get<1>(regionPair),
+ checkOpStructureEquivalent, checkValueEquivalent,
markEquivalent, flags))
return false;
@@ -832,9 +847,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
Operation *rhs,
Flags flags) {
+ return OperationEquivalence::isEquivalentTo(lhs, rhs, simpleOpEquivalence,
+ flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+ Operation *lhs, Operation *rhs,
+ function_ref<LogicalResult(Operation *, Operation *)>
+ checkOpStructureEquivalent,
+ Flags flags) {
ValueEquivalenceCache cache;
return OperationEquivalence::isEquivalentTo(
- lhs, rhs,
+ lhs, rhs, checkOpStructureEquivalent,
[&](Value lhsValue, Value rhsValue) -> LogicalResult {
return cache.checkEquivalent(lhsValue, rhsValue);
},
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index f2998b4047e201e..9982015ba3de779 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -411,6 +411,7 @@ struct ReturnLikeOpEquivalence : public llvm::DenseMapInfo<Operation *> {
return false;
return OperationEquivalence::isEquivalentTo(
const_cast<Operation *>(lhs), const_cast<Operation *>(rhs),
+ OperationEquivalence::simpleOpEquivalence,
OperationEquivalence::ignoreValueEquivalence, nullptr,
OperationEquivalence::IgnoreLocations);
}
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 1f2344677e6515c..e4e14baaeab7593 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
// Check that the operations are equivalent.
if (!OperationEquivalence::isEquivalentTo(
- &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+ &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+ OperationEquivalence::ignoreValueEquivalence,
/*markEquivalent=*/nullptr,
OperationEquivalence::Flags::IgnoreLocations))
return failure();
From 502d81f6c549a11622d661011e0403a56f8c6177 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Mon, 27 Nov 2023 14:42:06 +0100
Subject: [PATCH 2/2] [mlir][CSE] Introduce DialectInterface for CSE
Based on https://reviews.llvm.org/D154512 by @uenoku
Depends on https://github.com/llvm/llvm-project/pull/73455
Changed the interface return types to `std::optional` so that the interface can override equivalence for some ops and rely on the default implementation for others.
Original description:
This patch implements CSE dialect interfaces proposed in https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831/9.
DialectCSEInterface has three methods. `getHashValue` and `isEqual` are passed to the `DenseMapInfo` used by CSE, giving users a way to customize CSE behavior. `mergeOperations` is a callback invoked when an operation is eliminated, so that information from the eliminated operation can be propagated to the one that is kept.
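Here is a plain C++ sketch of the `mergeOperations` contract, modeled on the test dialect's behavior of keeping the lexicographically smaller `test.non_essential` value. `ToyOp`, `hashOp`, `mergeOperations`, and `mergeAndCheck` are hypothetical. `mergeAndCheck` imitates the debug-only assertion in `CSEDriver::replaceUsesAndDelete` that merging must not change the surviving op's hash.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>

// Hypothetical op: a name plus an optional non-essential attribute
// (playing the role of test.non_essential).
struct ToyOp {
  std::string name;
  std::optional<std::string> nonEssential;
};

// Hash that, like the test interface, ignores the non-essential attribute.
std::size_t hashOp(const ToyOp &op) {
  return std::hash<std::string>{}(op.name);
}

// Analogue of mergeOperations: keep the lexicographically smaller value,
// or adopt the erased op's value if the existing op has none.
void mergeOperations(ToyOp &existing, const ToyOp &erased) {
  if (!erased.nonEssential)
    return;
  if (!existing.nonEssential || *erased.nonEssential < *existing.nonEssential)
    existing.nonEssential = erased.nonEssential;
}

// Merge, then check that the surviving op's hash did not change.
void mergeAndCheck(ToyOp &existing, const ToyOp &erased) {
  std::size_t before = hashOp(existing);
  mergeOperations(existing, erased);
  assert(before == hashOp(existing) && "merge must not change the hash");
  (void)before; // silence unused-variable warnings when NDEBUG is set
}
```

Applied to the `b`, `a`, `c` sequence from the `cse.mlir` test, the surviving op ends up with `a`, which matches the expected CHECK line.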
---
mlir/include/mlir/Transforms/CSE.h | 37 ++++++++
mlir/lib/Transforms/CSE.cpp | 34 ++++++-
mlir/test/Transforms/cse.mlir | 11 +++
.../Dialect/Test/TestDialectInterfaces.cpp | 88 ++++++++++++++++++-
4 files changed, 168 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509d..053e60b487b6d2c 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,11 +13,14 @@
#ifndef MLIR_TRANSFORMS_CSE_H_
#define MLIR_TRANSFORMS_CSE_H_
+#include "mlir/IR/DialectInterface.h"
+
namespace mlir {
class DominanceInfo;
class Operation;
class RewriterBase;
+struct LogicalResult;
/// Eliminate common subexpressions within the given operation. This transform
/// looks for and deduplicates equivalent operations.
@@ -27,6 +30,40 @@ void eliminateCommonSubExpressions(RewriterBase &rewriter,
DominanceInfo &domInfo, Operation *op,
bool *changed = nullptr);
+//===----------------------------------------------------------------------===//
+// CSEInterface
+//===----------------------------------------------------------------------===//
+
+/// This is the interface that allows users to customize CSE.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+ DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// These two hooks are called in DenseMapInfo used by CSE.
+
+ /// Returns a hash for the operation.
+ /// CSE will use default implementation if `std::nullopt` is returned.
+ virtual std::optional<unsigned> getHashValue(Operation *op) const = 0;
+
+ /// Returns true if two operations are considered to be equivalent.
+ /// CSE will use default implementation if `std::nullopt` is returned.
+ virtual std::optional<bool> isEqual(Operation *lhs, Operation *rhs) const = 0;
+
+ //===--------------------------------------------------------------------===//
+ // Transformation Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// This hook is called when 'op' is considered to be a common subexpression
+ /// of 'existingOp' and is going to be eliminated. This hook allows users to
+ /// propagate information of 'op' to 'existingOp'. Note that the hash value of
+ /// 'existingOp' must not be changed due to the mutation of 'existingOp'.
+ virtual void mergeOperations(Operation *existingOp, Operation *op) const {}
+};
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_CSE_H_
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de59..771467429c8a8e2 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -35,8 +35,16 @@ using namespace mlir;
namespace {
struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
static unsigned getHashValue(const Operation *opC) {
+ auto *op = const_cast<Operation *>(opC);
+ // Use a custom hook if provided.
+ if (auto *interface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+ std::optional<unsigned> val = interface->getHashValue(op);
+ if (val)
+ return *val;
+ }
+
return OperationEquivalence::computeHash(
- const_cast<Operation *>(opC),
+ op,
/*hashOperands=*/OperationEquivalence::directHashValue,
/*hashResults=*/OperationEquivalence::ignoreHashValue,
OperationEquivalence::IgnoreLocations);
@@ -49,6 +57,18 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
+
+ if (lhs->getDialect() != rhs->getDialect())
+ return false;
+
+ if (auto *interface = dyn_cast<DialectCSEInterface>(lhs->getDialect())) {
+ std::optional<bool> val = interface->isEqual(lhs, rhs);
+ assert(val == interface->isEqual(rhs, lhs) &&
+ "DialectCSEInterface::isEqual must be symmetrical");
+ if (val)
+ return *val;
+ }
+
return OperationEquivalence::isEquivalentTo(
const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
OperationEquivalence::IgnoreLocations);
@@ -131,6 +151,18 @@ class CSEDriver {
void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing,
bool hasSSADominance) {
+ // Invoke a callback provided by CSE interface.
+ if (auto *cseInterface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+#ifndef NDEBUG
+ auto hashPrev = cseInterface->getHashValue(existing);
+#endif
+ cseInterface->mergeOperations(existing, op);
+#ifndef NDEBUG
+ assert(hashPrev == cseInterface->getHashValue(existing) &&
+ "hash values must not be modified by `mergeOperations`");
+#endif
+ }
+
// If we find one then replace all uses of the current operation with the
// existing one and mark it for deletion. We can only replace an operand in
// an operation if it has not been visited yet.
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8a..1e348f2d480c9eb 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -520,3 +520,14 @@ func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
%2 = "test.op_with_memread"() : () -> (i32)
return %0, %2, %1 : i32, i32, i32
}
+
+func.func @cse_test_dialect_interface() -> (i32, i32, i32) {
+ %0 = "test.always_speculatable_op"() {test.non_essential = "b"} : () -> i32
+ %1 = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+ %2 = "test.always_speculatable_op"() {test.non_essential = "c"} : () -> i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]], %[[result]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 80ddcdc8ea69f7c..6f5d86f3b9bc2ca 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -9,6 +9,7 @@
#include "TestDialect.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
@@ -385,6 +386,90 @@ struct TestReductionPatternInterface : public DialectReductionPatternInterface {
}
};
+/// This class defines the interface for customizing CSE.
+struct TestCSEInterface : public DialectCSEInterface {
+ using DialectCSEInterface::DialectCSEInterface;
+
+ StringRef nonEssentialAttrName = "test.non_essential";
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Return a hash that excludes 'test.non_essential' attributes.
+ std::optional<unsigned> getHashValue(Operation *op) const override {
+ auto attr = op->getDiscardableAttrDictionary();
+ if (!attr.contains(nonEssentialAttrName))
+ return std::nullopt;
+
+ auto hashOp = [&](Operation *op) {
+ auto hash = llvm::hash_combine(op->getName(), op->getResultTypes(),
+ op->hashProperties());
+ NamedAttrList attributes(attr);
+ attributes.erase(nonEssentialAttrName);
+ return llvm::hash_combine(hash,
+ attributes.getDictionary(op->getContext()));
+ };
+
+ return OperationEquivalence::computeHash(
+ op,
+ /*hashOp=*/hashOp,
+ /*hashOperands=*/OperationEquivalence::directHashValue,
+ /*hashResults=*/OperationEquivalence::ignoreHashValue,
+ OperationEquivalence::IgnoreLocations);
+ }
+
+ /// Return true if operations are the same except for 'test.non_essential'
+ /// attributes.
+ std::optional<bool> isEqual(Operation *lhs, Operation *rhs) const override {
+ if (!lhs->getDiscardableAttrDictionary().contains(nonEssentialAttrName) ||
+ !rhs->getDiscardableAttrDictionary().contains(nonEssentialAttrName))
+ return std::nullopt;
+
+ auto checkOp = [&](Operation *lhs, Operation *rhs) -> LogicalResult {
+ bool result =
+ lhs->getName() == rhs->getName() &&
+ lhs->getNumRegions() == rhs->getNumRegions() &&
+ lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+ lhs->getNumOperands() == rhs->getNumOperands() &&
+ lhs->getNumResults() == rhs->getNumResults() &&
+ lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+ rhs->getPropertiesStorage());
+ if (!result)
+ return failure();
+ auto lhsAttr = lhs->getDiscardableAttrs();
+ auto rhsAttr = rhs->getDiscardableAttrs();
+ NamedAttrList lhsFiltered(lhsAttr);
+ lhsFiltered.erase(nonEssentialAttrName);
+ NamedAttrList rhsFiltered(rhsAttr);
+ rhsFiltered.erase(nonEssentialAttrName);
+ return LogicalResult::success(lhsFiltered == rhsFiltered);
+ };
+
+ return OperationEquivalence::isEquivalentTo(
+ lhs, rhs, checkOp, OperationEquivalence::IgnoreLocations);
+ };
+
+ //===--------------------------------------------------------------------===//
+ // Transformation Hooks
+ //===--------------------------------------------------------------------===//
+
+ /// Propagate 'test.non_essential' to an existing op. Use the smaller value if
+ /// both have the attribute.
+ void mergeOperations(Operation *existingOp,
+ Operation *opsToBeDeleted) const override {
+ if (auto rhs =
+ opsToBeDeleted->getAttrOfType<StringAttr>(nonEssentialAttrName))
+ if (auto lhs =
+ existingOp->getAttrOfType<StringAttr>(nonEssentialAttrName)) {
+ if (rhs.getValue() < lhs.getValue()) {
+ existingOp->setAttr(nonEssentialAttrName, rhs);
+ }
+ } else {
+ existingOp->setAttr(nonEssentialAttrName, rhs);
+ }
+ }
+};
} // namespace
void TestDialect::registerInterfaces() {
@@ -392,5 +477,6 @@ void TestDialect::registerInterfaces() {
addInterface<TestOpAsmInterface>(blobInterface);
addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
- TestReductionPatternInterface, TestBytecodeDialectInterface>();
+ TestReductionPatternInterface, TestCSEInterface,
+ TestBytecodeDialectInterface>();
}