[Mlir-commits] [mlir] [mlir][OperationEquivalence] Add an extra callback to hook operation equivalence check (PR #73455)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 26 10:59:31 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Rebase of https://reviews.llvm.org/D155577 by @<!-- -->uenoku

The only changes are the resolution of a merge conflict around the `simpleOpEquivalence` function and an update to `CFGToSCF.cpp`.

Original description:

This patch adds an extra callback, `checkOpStructureEquivalent`, to the method `OperationEquivalence::isEquivalentTo` so that users can customize how equivalence is determined for the structural properties of operations.

I'm not sure that "structural properties" is the best term to describe the operation information used here. Plain "properties" might be better, but I added "structural" to clearly distinguish it from operation properties.

This patch is in preparation for providing a dialect interface for CSE; see the RFC at https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831

---
Full diff: https://github.com/llvm/llvm-project/pull/73455.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/OperationSupport.h (+38-10) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+50-26) 
- (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+1) 
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b4..d6c5d50efb7adc1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1236,16 +1236,31 @@ struct OperationEquivalence {
   };
 
   /// Compute a hash for the given operation.
-  /// The `hashOperands` and `hashResults` callbacks are expected to return a
-  /// unique hash_code for a given Value.
+  /// The `hashOp` is a callback to compute a hash for structural properties of
+  /// `op` such as op name, result types and attributes. The `hashOperands` and
+  /// `hashResults` callbacks are expected to return a unique hash_code for a
+  /// given Value.
   static llvm::hash_code computeHash(
-      Operation *op,
+      Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
       function_ref<llvm::hash_code(Value)> hashOperands =
           [](Value v) { return hash_value(v); },
       function_ref<llvm::hash_code(Value)> hashResults =
           [](Value v) { return hash_value(v); },
       Flags flags = Flags::None);
 
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None) {
+    return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+  }
+
+  /// Helper that can be used with `computeHash` above to combine hashes for
+  /// basic structural properties.
+  static llvm::hash_code simpleOpHash(Operation *op);
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1255,38 +1270,51 @@ struct OperationEquivalence {
 
   /// Compare two operations (including their regions) and return if they are
   /// equivalent.
-  ///
-  /// * `checkEquivalent` is a callback to check if two values are equivalent.
+  /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+  /// two operations are equivalent.
+  /// * `checkValueEquivalent` is a callback to check if two values are
+  /// equivalent.
   ///   For two operations to be equivalent, their operands must be the same SSA
   ///   value or this callback must return `success`.
   /// * `markEquivalent` is a callback to inform the caller that the analysis
   ///   determined that two values are equivalent.
   ///
   /// Note: Additional information regarding value equivalence can be injected
-  /// into the analysis via `checkEquivalent`. Typically, callers may want
-  /// values that were determined to be equivalent as per `markEquivalent` to be
-  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// into the analysis via `checkOpStructureEquivalent` and
+  /// `checkValueEquivalent`. Typically, callers may want values that were
+  /// determined to be equivalent as per `markEquivalent` to be reflected in
+  /// `checkValueEquivalent`, unless `exactValueMatch` or a different
   /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 Flags flags);
 
   /// Compare two regions (including their subregions) and return if they are
   /// equivalent. See also `isEquivalentTo` for details.
   static bool isRegionEquivalentTo(
       Region *lhs, Region *rhs,
-      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<LogicalResult(Operation *, Operation *)>
+          checkOpStructureEquivalent,
+      function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
       OperationEquivalence::Flags flags);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
                                    OperationEquivalence::Flags flags);
+  static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
 
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d8..1d2189e530cdf60 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -668,15 +668,11 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
 //===----------------------------------------------------------------------===//
 
 llvm::hash_code OperationEquivalence::computeHash(
-    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+    function_ref<llvm::hash_code(Value)> hashOperands,
     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
-  // Hash operations based upon their:
-  //   - Operation Name
-  //   - Attributes
-  //   - Result Types
-  llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
-                         op->getResultTypes(), op->hashProperties());
+  // Hash operations based upon their structural properties using `hashOp`.
+  llvm::hash_code hash = hashOp(op);
 
   //   - Location if required
   if (!(flags & Flags::IgnoreLocations))
@@ -694,7 +690,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
     Region *lhs, Region *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
     OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
@@ -724,8 +722,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+      if (!OperationEquivalence::isEquivalentTo(
+              &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+              markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -766,7 +765,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
                                            OperationEquivalence::Flags flags) {
   ValueEquivalenceCache cache;
   return isRegionEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, simpleOpEquivalence,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
@@ -776,24 +775,39 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       flags);
 }
 
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+  return llvm::hash_combine(op->getName(), op->getResultTypes(),
+                            op->hashProperties(),
+                            op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+  return LogicalResult::success(
+      lhs->getName() == rhs->getName() &&
+      lhs->getDiscardableAttrDictionary() ==
+          rhs->getDiscardableAttrDictionary() &&
+      lhs->getNumRegions() == rhs->getNumRegions() &&
+      lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+      lhs->getNumOperands() == rhs->getNumOperands() &&
+      lhs->getNumResults() == rhs->getNumResults() &&
+      lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+                                         rhs->getPropertiesStorage()));
+}
+
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // 1. Compare the operation properties.
-  if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
-      lhs->getNumRegions() != rhs->getNumRegions() ||
-      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
-      lhs->getNumOperands() != rhs->getNumOperands() ||
-      lhs->getNumResults() != rhs->getNumResults() ||
-      !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
-                                        rhs->getPropertiesStorage()))
+  // 1. Compare the operation structural properties.
+  if (failed(checkOpStructureEquivalent(lhs, rhs)))
     return false;
+
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
@@ -805,7 +819,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       continue;
     if (curArg.getType() != otherArg.getType())
       return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+    if (failed(checkValueEquivalent(curArg, otherArg)))
       return false;
   }
 
@@ -822,7 +836,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
+                              &std::get<1>(regionPair),
+                              checkOpStructureEquivalent, checkValueEquivalent,
                               markEquivalent, flags))
       return false;
 
@@ -832,9 +847,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
                                                      Operation *rhs,
                                                      Flags flags) {
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, simpleOpEquivalence,
+                                              flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    Flags flags) {
   ValueEquivalenceCache cache;
   return OperationEquivalence::isEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, checkOpStructureEquivalent,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index f2998b4047e201e..9982015ba3de779 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -411,6 +411,7 @@ struct ReturnLikeOpEquivalence : public llvm::DenseMapInfo<Operation *> {
       return false;
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhs), const_cast<Operation *>(rhs),
+        OperationEquivalence::simpleOpEquivalence,
         OperationEquivalence::ignoreValueEquivalence, nullptr,
         OperationEquivalence::IgnoreLocations);
   }
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 1f2344677e6515c..e4e14baaeab7593 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
             /*markEquivalent=*/nullptr,
             OperationEquivalence::Flags::IgnoreLocations))
       return failure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/73455


More information about the Mlir-commits mailing list