[Mlir-commits] [mlir] 31fc47e - [MLIR] Expose region equivalence check through OperationEquivalence

Frederik Gossen llvmlistbot at llvm.org
Mon Feb 27 07:50:08 PST 2023


Author: Frederik Gossen
Date: 2023-02-27T10:49:51-05:00
New Revision: 31fc47e3ab8b6703166e27ac02b056ce1fec0fcc

URL: https://github.com/llvm/llvm-project/commit/31fc47e3ab8b6703166e27ac02b056ce1fec0fcc
DIFF: https://github.com/llvm/llvm-project/commit/31fc47e3ab8b6703166e27ac02b056ce1fec0fcc.diff

LOG: [MLIR] Expose region equivalence check through OperationEquivalence

Differential Revision: https://reviews.llvm.org/D144735

Added: 
    

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index da95260508cb9..ebeb0a96523bd 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -949,6 +949,18 @@ struct OperationEquivalence {
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
 
+  /// Compare two regions (including their subregions) and return if they are
+  /// equivalent. See also `isEquivalentTo` for details.
+  static bool isRegionEquivalentTo(
+      Region *lhs, Region *rhs,
+      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<void(Value, Value)> markEquivalent, Flags flags);
+
+  /// Compare two regions and return if they are equivalent. Value
+  /// equivalences are tracked in an internal cache.
+  static bool isRegionEquivalentTo(Region *lhs, Region *rhs, Flags flags);
+
+
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
   static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index a38a12def567c..f6167be20cf83 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -655,11 +655,11 @@ llvm::hash_code OperationEquivalence::computeHash(
   return hash;
 }
 
-static bool
-isRegionEquivalentTo(Region *lhs, Region *rhs,
-                     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-                     function_ref<void(Value, Value)> markEquivalent,
-                     OperationEquivalence::Flags flags) {
+/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
+    Region *lhs, Region *rhs,
+    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<void(Value, Value)> markEquivalent,
+    OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -706,7 +706,40 @@ isRegionEquivalentTo(Region *lhs, Region *rhs,
   return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
 }
 
-bool OperationEquivalence::isEquivalentTo(
+namespace {
+// Value equivalence cache for `isRegionEquivalentTo` and `isEquivalentTo`.
+struct ValueEquivalenceCache {
+  DenseMap<Value, Value> equivalentValues;
+  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
+    return success(lhsValue == rhsValue ||
+                   equivalentValues.lookup(lhsValue) == rhsValue);
+  }
+  void markEquivalent(Value lhsResult, Value rhsResult) {
+    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+    // An lhs value must never be marked equivalent to two different values.
+    (void)insertion;
+    assert(insertion.first->second == rhsResult &&
+           "inconsistent OperationEquivalence state");
+  }
+};
+} // namespace
+
+// Overload that records value equivalences in a local ValueEquivalenceCache.
+/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
+    Region *lhs, Region *rhs, Flags flags) {
+  ValueEquivalenceCache cache;
+  return isRegionEquivalentTo(
+      lhs, rhs,
+      [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+        return cache.checkEquivalent(lhsValue, rhsValue);
+      },
+      [&](Value lhsResult, Value rhsResult) {
+        cache.markEquivalent(lhsResult, rhsResult);
+      },
+      flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
@@ -790,24 +823,19 @@ bool OperationEquivalence::isEquivalentTo(
   return true;
 }
 
-bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
-                                          Flags flags) {
-  // Equivalent values in lhs and rhs.
-  DenseMap<Value, Value> equivalentValues;
-  auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
-    return success(lhsValue == rhsValue ||
-                   equivalentValues.lookup(lhsValue) == rhsValue);
-  };
-  auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
-    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
-    // Make sure that the value was not already marked equivalent to some other
-    // value.
-    (void)insertion;
-    assert(insertion.first->second == rhsResult &&
-           "inconsistent OperationEquivalence state");
-  };
-  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
-                                              markEquivalent, flags);
+/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
+                                                     Operation *rhs,
+                                                     Flags flags) {
+  // Cache of value pairs marked equivalent while comparing `lhs` and `rhs`.
+  ValueEquivalenceCache valueCache;
+  auto checkValues = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+    return valueCache.checkEquivalent(lhsValue, rhsValue);
+  };
+  auto markValues = [&](Value lhsResult, Value rhsResult) {
+    valueCache.markEquivalent(lhsResult, rhsResult);
+  };
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkValues,
+                                              markValues, flags);
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list