[Mlir-commits] [mlir] c864288 - [mlir][transforms] Simplify OperationEquivalence and CSE
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 02:57:20 PST 2023
Author: Matthias Springer
Date: 2023-01-27T11:56:48+01:00
New Revision: c864288dcf169fd1aa663320b7414c0ee72d5c2b
URL: https://github.com/llvm/llvm-project/commit/c864288dcf169fd1aa663320b7414c0ee72d5c2b
DIFF: https://github.com/llvm/llvm-project/commit/c864288dcf169fd1aa663320b7414c0ee72d5c2b.diff
LOG: [mlir][transforms] Simplify OperationEquivalence and CSE
Replace `mapOperands` and `mapResults` with two new callbacks. It was not clear what "mapping" meant and why the equivalence relationship was a property of the Operand/OpResult as opposed to just SSA values.
This revision changes the contract of the two callbacks: `checkEquivalent` compares two values for equivalence. `markEquivalent` informs the caller that the analysis determined that two values are equivalent. This simplifies the API because callers do not have to reason about operands/results, but just SSA values.
`OperationEquivalence::isEquivalentTo` can be used directly in CSE and there is no need for a custom op equivalence analysis.
Differential Revision: https://reviews.llvm.org/D142558
Added:
Modified:
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/test/lib/IR/TestOperationEquals.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 99e63f207e96b..003029d6db94b 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -926,24 +926,36 @@ struct OperationEquivalence {
/// operands/result mapping.
static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
- /// Compare two operations and return if they are equivalent.
- /// `mapOperands` and `mapResults` are optional callbacks that allows the
- /// caller to check the mapping of SSA value between the lhs and rhs
- /// operations. It is expected to return success if the mapping is valid and
- /// failure if it conflicts with a previous mapping.
+ /// Compare two operations (including their regions) and return if they are
+ /// equivalent.
+ ///
+ /// * `checkEquivalent` is a callback to check if two values are equivalent.
+ /// For two operations to be equivalent, their operands must be the same SSA
+ /// value or this callback must return `success`.
+ /// * `markEquivalent` is a callback to inform the caller that the analysis
+ /// determined that two values are equivalent.
+ ///
+ /// Note: Additional information regarding value equivalence can be injected
+ /// into the analysis via `checkEquivalent`. Typically, callers may want
+ /// values that were determined to be equivalent as per `markEquivalent` to be
+ /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+ /// equivalence relationship is desired.
static bool
isEquivalentTo(Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> mapOperands,
- function_ref<LogicalResult(Value, Value)> mapResults,
+ function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<void(Value, Value)> markEquivalent = nullptr,
Flags flags = Flags::None);
- /// Helper that can be used with `isEquivalentTo` above to ignore operation
- /// operands/result mapping.
+ /// Compare two operations and return if they are equivalent.
+ static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+
+ /// Helper that can be used with `isEquivalentTo` above to consider ops
+ /// equivalent even if their operands are not equivalent.
static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
return success();
}
- /// Helper that can be used with `isEquivalentTo` above to ignore operation
- /// operands/result mapping.
+ /// Helper that can be used with `isEquivalentTo` above to consider ops
+ /// equivalent only if their operands are the exact same SSA values.
static LogicalResult exactValueMatch(Value lhs, Value rhs) {
return success(lhs == rhs);
}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 8f3d37976d2a8..20ce9b36737fb 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -652,8 +652,8 @@ llvm::hash_code OperationEquivalence::computeHash(
static bool
isRegionEquivalentTo(Region *lhs, Region *rhs,
- function_ref<LogicalResult(Value, Value)> mapOperands,
- function_ref<LogicalResult(Value, Value)> mapResults,
+ function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<void(Value, Value)> markEquivalent,
OperationEquivalence::Flags flags) {
DenseMap<Block *, Block *> blocksMap;
auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
@@ -675,15 +675,15 @@ isRegionEquivalentTo(Region *lhs, Region *rhs,
if (!(flags & OperationEquivalence::IgnoreLocations) &&
curArg.getLoc() != otherArg.getLoc())
return false;
- // Check if this value was already mapped to another value.
- if (failed(mapOperands(curArg, otherArg)))
- return false;
+ // Corresponding bbArgs are equivalent.
+ if (markEquivalent)
+ markEquivalent(curArg, otherArg);
}
auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
// Check for op equality (recursively).
- if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
- mapResults, flags))
+ if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
+ markEquivalent, flags))
return false;
// Check successor mapping.
for (auto successorsPair :
@@ -703,12 +703,12 @@ isRegionEquivalentTo(Region *lhs, Region *rhs,
bool OperationEquivalence::isEquivalentTo(
Operation *lhs, Operation *rhs,
- function_ref<LogicalResult(Value, Value)> mapOperands,
- function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
+ function_ref<LogicalResult(Value, Value)> checkEquivalent,
+ function_ref<void(Value, Value)> markEquivalent, Flags flags) {
if (lhs == rhs)
return true;
- // Compare the operation properties.
+ // 1. Compare the operation properties.
if (lhs->getName() != rhs->getName() ||
lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
lhs->getNumRegions() != rhs->getNumRegions() ||
@@ -719,6 +719,7 @@ bool OperationEquivalence::isEquivalentTo(
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
+ // 2. Compare operands.
ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
@@ -752,32 +753,58 @@ bool OperationEquivalence::isEquivalentTo(
rhsOperandStorage = sortValues(rhsOperands);
rhsOperands = rhsOperandStorage;
}
- auto checkValueRangeMapping =
- [](ValueRange lhs, ValueRange rhs,
- function_ref<LogicalResult(Value, Value)> mapValues) {
- for (auto operandPair : llvm::zip(lhs, rhs)) {
- Value curArg = std::get<0>(operandPair);
- Value otherArg = std::get<1>(operandPair);
- if (curArg.getType() != otherArg.getType())
- return false;
- if (failed(mapValues(curArg, otherArg)))
- return false;
- }
- return true;
- };
- // Check mapping of operands and results.
- if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
- return false;
- if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
- return false;
+
+ for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+ Value curArg = std::get<0>(operandPair);
+ Value otherArg = std::get<1>(operandPair);
+ if (curArg == otherArg)
+ continue;
+ if (curArg.getType() != otherArg.getType())
+ return false;
+ if (failed(checkEquivalent(curArg, otherArg)))
+ return false;
+ }
+
+ // 3. Compare result types and mark results as equivalent.
+ for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
+ Value curArg = std::get<0>(resultPair);
+ Value otherArg = std::get<1>(resultPair);
+ if (curArg.getType() != otherArg.getType())
+ return false;
+ if (markEquivalent)
+ markEquivalent(curArg, otherArg);
+ }
+
+ // 4. Compare regions.
for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
if (!isRegionEquivalentTo(&std::get<0>(regionPair),
- &std::get<1>(regionPair), mapOperands, mapResults,
- flags))
+ &std::get<1>(regionPair), checkEquivalent,
+ markEquivalent, flags))
return false;
+
return true;
}
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
+ Flags flags) {
+ // Equivalent values in lhs and rhs.
+ DenseMap<Value, Value> equivalentValues;
+ auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+ return success(lhsValue == rhsValue ||
+ equivalentValues.lookup(lhsValue) == rhsValue);
+ };
+ auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
+ auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+ // Make sure that the value was not already marked equivalent to some other
+ // value.
+ (void)insertion;
+ assert(insertion.first->second == rhsResult &&
+ "inconsistent OperationEquivalence state");
+ };
+ return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
+ markEquivalent, flags);
+}
+
//===----------------------------------------------------------------------===//
// OperationFingerPrint
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index cd0f90b411f0a..86debe7271fc1 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -47,70 +47,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
-
- // If op has no regions, operation equivalence w.r.t operands alone is
- // enough.
- if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
- return OperationEquivalence::isEquivalentTo(
- const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
- OperationEquivalence::exactValueMatch,
- OperationEquivalence::ignoreValueEquivalence,
- OperationEquivalence::IgnoreLocations);
- }
-
- // If lhs or rhs does not have a single region with a single block, they
- // aren't CSEed for now.
- if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
- !llvm::hasSingleElement(lhs->getRegion(0)) ||
- !llvm::hasSingleElement(rhs->getRegion(0)))
- return false;
-
- // Compare the two blocks.
- Block &lhsBlock = lhs->getRegion(0).front();
- Block &rhsBlock = rhs->getRegion(0).front();
-
- // Don't CSE if number of arguments differ.
- if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
- return false;
-
- // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
- // `rhsBlock`. `Value`s from `lhsBlock` are the key.
- DenseMap<Value, Value> areEquivalentValues;
- for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
- rhs->getRegion(0).getArguments())) {
- areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
- }
-
- // Helper function to get the parent operation.
- auto getParent = [](Value v) -> Operation * {
- if (auto blockArg = v.dyn_cast<BlockArgument>())
- return blockArg.getParentBlock()->getParentOp();
- return v.getDefiningOp()->getParentOp();
- };
-
- // Callback to compare if operands of ops in the region of `lhs` and `rhs`
- // are equivalent.
- auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
- if (lhsValue == rhsValue)
- return success();
- if (areEquivalentValues.lookup(lhsValue) == rhsValue)
- return success();
- return failure();
- };
-
- // Callback to compare if results of ops in the region of `lhs` and `rhs`
- // are equivalent.
- auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
- if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
- auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
- return success(insertion.first->second == rhsResult);
- }
- return success();
- };
-
return OperationEquivalence::isEquivalentTo(
const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
- mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
+ OperationEquivalence::IgnoreLocations);
}
};
} // namespace
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 996588243f565..e5824bd6663b0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -493,7 +493,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
// Check that the operations are equivalent.
if (!OperationEquivalence::isEquivalentTo(
&*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
- OperationEquivalence::ignoreValueEquivalence,
+ /*markEquivalent=*/nullptr,
OperationEquivalence::Flags::IgnoreLocations))
return failure();
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
index 6bfb59ef55533..ef3589636e529 100644
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -28,11 +28,6 @@ struct TestOperationEqualPass
<< opCount;
return signalPassFailure();
}
- DenseMap<Value, Value> valuesMap;
- auto mapValue = [&](Value lhs, Value rhs) {
- auto insertion = valuesMap.insert({lhs, rhs});
- return success(insertion.first->second == rhs);
- };
Operation *first = &module.getBody()->front();
llvm::outs() << first->getName().getStringRef() << " with attr "
@@ -41,7 +36,7 @@ struct TestOperationEqualPass
if (!first->hasAttr("strict_loc_check"))
flags |= OperationEquivalence::IgnoreLocations;
if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
- mapValue, mapValue, flags))
+ flags))
llvm::outs() << " compares equals.\n";
else
llvm::outs() << " compares NOT equals!\n";
More information about the Mlir-commits mailing list