[Mlir-commits] [mlir] 0be5d1a - Implement recursive support into OperationEquivalence::isEquivalentTo()

Mehdi Amini llvmlistbot at llvm.org
Wed Jul 28 22:12:10 PDT 2021


Author: Mehdi Amini
Date: 2021-07-29T05:06:37Z
New Revision: 0be5d1a96c8973a4aa56b3fdd8fc22c8a95a7171

URL: https://github.com/llvm/llvm-project/commit/0be5d1a96c8973a4aa56b3fdd8fc22c8a95a7171
DIFF: https://github.com/llvm/llvm-project/commit/0be5d1a96c8973a4aa56b3fdd8fc22c8a95a7171.diff

LOG: Implement recursive support into OperationEquivalence::isEquivalentTo()

This allows OperationEquivalence to be used for structural equality comparison
between two operations.

Differential Revision: https://reviews.llvm.org/D106422

Added: 
    mlir/test/IR/operation-equality.mlir
    mlir/test/lib/IR/TestOperationEquals.cpp

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Transforms/CSE.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 43be1ab67d9ba..111e2b8d0bc0c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -901,21 +901,51 @@ struct OperationEquivalence {
   enum Flags {
     None = 0,
 
-    /// This flag signals that operands should not be considered when checking
-    /// for equivalence. This allows for users to implement there own
-    /// equivalence schemes for operand values. The number of operands are still
-    /// checked, just not the operands themselves.
-    IgnoreOperands = 1,
+    /// When provided, the locations attached to the operations are ignored.
+    IgnoreLocations = 1,
 
-    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands)
+    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations)
   };
 
   /// Compute a hash for the given operation.
-  static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None);
+  /// The `hashOperands` and `hashResults` callbacks are expected to return a
+  /// unique hash_code for a given Value.
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None);
+
+  /// Helper that can be used with `computeHash` above to ignore operation
+  /// operands/result mapping.
+  static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
+  /// Helper that can be used with `computeHash` above to hash operands/results
+  /// directly with `hash_value`.
+  static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
 
   /// Compare two operations and return if they are equivalent.
-  static bool isEquivalentTo(Operation *lhs, Operation *rhs,
-                             Flags flags = Flags::None);
+  /// `mapOperands` and `mapResults` are callbacks that allow the caller to
+  /// check the mapping of SSA values between the lhs and rhs operations. They
+  /// are expected to return success if the mapping is valid and failure if it
+  /// conflicts with a previous mapping.
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Value, Value)> mapOperands,
+                 function_ref<LogicalResult(Value, Value)> mapResults,
+                 Flags flags = Flags::None);
+
+  /// Helper that can be used with `isEquivalentTo` above to ignore operation
+  /// operands/result mapping.
+  static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
+    return success();
+  }
+  /// Helper that can be used with `isEquivalentTo` above to require that
+  /// operands/results are the exact same Value.
+  static LogicalResult exactValueMatch(Value lhs, Value rhs) {
+    return success(lhs == rhs);
+  }
 };
 
 /// Enable Bitmask enums for OperationEquivalence::Flags.

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index cfa1302331714..bb9a5603f8714 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,7 +522,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
 // Operation Equivalency
 //===----------------------------------------------------------------------===//
 
-llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
+llvm::hash_code OperationEquivalence::computeHash(
+    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
   // Hash operations based upon their:
   //   - Operation Name
   //   - Attributes
@@ -531,37 +533,106 @@ llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) {
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   //   - Operands
-  bool ignoreOperands = flags & Flags::IgnoreOperands;
-  if (!ignoreOperands) {
-    // TODO: Allow commutative operations to have different ordering.
-    hash = llvm::hash_combine(
-        hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
-  }
+  for (Value operand : op->getOperands())
+    hash = llvm::hash_combine(hash, hashOperands(operand));
+  //   - Results
+  for (Value result : op->getResults())
+    hash = llvm::hash_combine(hash, hashResults(result));
   return hash;
 }
 
-bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
-                                          Flags flags) {
+static bool
+isRegionEquivalentTo(Region *lhs, Region *rhs,
+                     function_ref<LogicalResult(Value, Value)> mapOperands,
+                     function_ref<LogicalResult(Value, Value)> mapResults,
+                     OperationEquivalence::Flags flags) {
+  DenseMap<Block *, Block *> blocksMap;
+  auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
+    // Check block arguments.
+    if (lBlock.getNumArguments() != rBlock.getNumArguments())
+      return false;
+
+    // Map the two blocks.
+    auto insertion = blocksMap.insert({&lBlock, &rBlock});
+    if (insertion.first->getSecond() != &rBlock)
+      return false;
+
+    for (auto argPair :
+         llvm::zip(lBlock.getArguments(), rBlock.getArguments())) {
+      Value curArg = std::get<0>(argPair);
+      Value otherArg = std::get<1>(argPair);
+      if (curArg.getType() != otherArg.getType())
+        return false;
+      if (!(flags & OperationEquivalence::IgnoreLocations) &&
+          curArg.getLoc() != otherArg.getLoc())
+        return false;
+      // Check if this value was already mapped to another value.
+      if (failed(mapOperands(curArg, otherArg)))
+        return false;
+    }
+
+    auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
+      // Check for op equality (recursively).
+      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
+                                                mapResults, flags))
+        return false;
+      // Check successor mapping.
+      for (auto successorsPair :
+           llvm::zip(lOp.getSuccessors(), rOp.getSuccessors())) {
+        Block *curSuccessor = std::get<0>(successorsPair);
+        Block *otherSuccessor = std::get<1>(successorsPair);
+        auto insertion = blocksMap.insert({curSuccessor, otherSuccessor});
+        if (insertion.first->getSecond() != otherSuccessor)
+          return false;
+      }
+      return true;
+    };
+    return llvm::all_of_zip(lBlock, rBlock, opsEquivalent);
+  };
+  return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
+}
+
+bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Value, Value)> mapOperands,
+    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // Compare the operation name.
-  if (lhs->getName() != rhs->getName())
+  // Compare the operation properties.
+  if (lhs->getName() != rhs->getName() ||
+      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
+      lhs->getNumRegions() != rhs->getNumRegions() ||
+      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
+      lhs->getNumOperands() != rhs->getNumOperands() ||
+      lhs->getNumResults() != rhs->getNumResults())
     return false;
-  // Check operand counts.
-  if (lhs->getNumOperands() != rhs->getNumOperands())
+  if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
-  // Compare attributes.
-  if (lhs->getAttrDictionary() != rhs->getAttrDictionary())
+
+  auto checkValueRangeMapping =
+      [](ValueRange lhs, ValueRange rhs,
+         function_ref<LogicalResult(Value, Value)> mapValues) {
+        for (auto operandPair : llvm::zip(lhs, rhs)) {
+          Value curArg = std::get<0>(operandPair);
+          Value otherArg = std::get<1>(operandPair);
+          if (curArg.getType() != otherArg.getType())
+            return false;
+          if (failed(mapValues(curArg, otherArg)))
+            return false;
+        }
+        return true;
+      };
+  // Check mapping of operands and results.
+  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
+                              mapOperands))
     return false;
-  // Compare result types.
-  if (lhs->getResultTypes() != rhs->getResultTypes())
+  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
     return false;
-  // Compare operands.
-  bool ignoreOperands = flags & Flags::IgnoreOperands;
-  if (ignoreOperands)
-    return true;
-  // TODO: Allow commutative operations to have different ordering.
-  return std::equal(lhs->operand_begin(), lhs->operand_end(),
-                    rhs->operand_begin());
+  for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
+    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
+                              &std::get<1>(regionPair), mapOperands, mapResults,
+                              flags))
+      return false;
+  return true;
 }

diff  --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 4b2ba0e9ce487..c7b09b5f1a180 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -28,7 +28,11 @@ using namespace mlir;
 namespace {
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
-    return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
+    return OperationEquivalence::computeHash(
+        const_cast<Operation *>(opC),
+        /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
@@ -38,8 +42,11 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-    return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
-                                                const_cast<Operation *>(rhsC));
+    return OperationEquivalence::isEquivalentTo(
+        const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
+        /*mapOperands=*/OperationEquivalence::exactValueMatch,
+        /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // end anonymous namespace

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 7dedb7df9afe2..26868d17a66e0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -428,7 +428,9 @@ BlockEquivalenceData::BlockEquivalenceData(Block *block)
       orderIt += numResults;
     }
     auto opHash = OperationEquivalence::computeHash(
-        &op, OperationEquivalence::Flags::IgnoreOperands);
+        &op, OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
     hash = llvm::hash_combine(hash, opHash);
   }
 }
@@ -491,7 +493,9 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands))
+            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
+            OperationEquivalence::Flags::IgnoreLocations))
       return failure();
 
     // Compare the operands of the two operations. If the operand is within

diff  --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
new file mode 100644
index 0000000000000..f382d7d0fbf1b
--- /dev/null
+++ b/mlir/test/IR/operation-equality.mlir
@@ -0,0 +1,186 @@
+// RUN: mlir-opt %s -split-input-file --test-operations-equality | FileCheck %s
+
+
+// CHECK-LABEL: test.top_level_op
+// CHECK-SAME: compares equals
+
+"test.top_level_op"() : () -> ()
+"test.top_level_op"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_strict_loc
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_strict_loc"() { strict_loc_check } : () -> ()
+"test.top_level_op_strict_loc"() { strict_loc_check } : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_loc_match
+// CHECK-SAME: compares equals
+
+"test.top_level_op_loc_match"() { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_loc_match"() { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_block_loc_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_block_loc_mismatch"() ({
+ ^bb0(%a : i32):
+}) { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_block_loc_mismatch"() ({
+ ^bb0(%a : i32):
+}) { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_block_loc_match
+// CHECK-SAME: compares equals
+
+"test.top_level_op_block_loc_match"() ({
+ ^bb0(%a : i32 loc("bar")):
+}) { strict_loc_check } : () -> () loc("foo")
+"test.top_level_op_block_loc_match"() ({
+ ^bb0(%a : i32 loc("bar")):
+}) { strict_loc_check } : () -> () loc("foo")
+
+// -----
+
+// CHECK-LABEL: test.top_level_name_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_name_mismatch"() : () -> ()
+"test.top_level_name_mismatch2"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_attr_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.top_level_op_attr_mismatch"() { foo = "bar" } : () -> ()
+"test.top_level_op_attr_mismatch"() { foo = "bar2"} : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.top_level_op_cfg
+// CHECK-SAME: compares equals
+
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+   { attr = "foo" } : () -> ()
+"test.top_level_op_cfg"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  }, {
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) [^bb1, ^bb2] : (f32, i32) -> ()
+  ^bb1(%arg2 : f32):
+    "test.some_branching_op"() : () -> ()
+  ^bb2(%arg3 : i32):
+    "test.some_branching_op"() : () -> ()
+  })
+   { attr = "foo" } : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_num_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1) : (f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.operand_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg0) : (f32, i32) -> ()
+  }) : () -> ()
+"test.operand_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"(%arg1, %arg1) : (f32, f32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_type_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_type_mismatch"() ({
+  ^bb0(%arg0 : i32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.block_arg_num_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32, %arg1 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+"test.block_arg_num_mismatch"() ({
+  ^bb0(%arg0 : f32):
+    "test.some_branching_op"() : () -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.dataflow_match
+// CHECK-SAME: compares equals
+
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_match"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.dataflow_mismatch
+// CHECK-SAME: compares NOT equals
+
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#0, %0#1) : (i32, i32) -> ()
+  }) : () -> ()
+"test.dataflow_mismatch"() ({
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.consumer"(%0#1, %0#0) : (i32, i32) -> ()
+  }) : () -> ()

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index e809ea1fc0b4d..07d6e4f505c7d 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR
   TestInterfaces.cpp
   TestMatchers.cpp
   TestOpaqueLoc.cpp
+  TestOperationEquals.cpp
   TestPrintDefUse.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp

diff  --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
new file mode 100644
index 0000000000000..0ef6dfe9e84ba
--- /dev/null
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -0,0 +1,55 @@
+//===- TestOperationEquals.cpp - Passes to test OperationEquivalence ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This pass compares the module's two top-level ops for structural equality.
+struct TestOperationEqualPass
+    : public PassWrapper<TestOperationEqualPass, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-operations-equality"; }
+  StringRef getDescription() const final { return "Test operations equality."; }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    // Expects exactly two operations at the top-level.
+    int opCount = module.getBody()->getOperations().size();
+    if (opCount != 2) {
+      module.emitError() << "expected 2 top-level ops in the module, got "
+                         << opCount;
+      return signalPassFailure();
+    }
+    DenseMap<Value, Value> valuesMap;
+    auto mapValue = [&](Value lhs, Value rhs) {
+      auto insertion = valuesMap.insert({lhs, rhs});
+      return success(insertion.first->second == rhs);
+    };
+
+    Operation *first = &module.getBody()->front();
+    llvm::outs() << first->getName().getStringRef() << " with attr "
+                 << first->getAttrDictionary();
+    OperationEquivalence::Flags flags{};
+    if (!first->hasAttr("strict_loc_check"))
+      flags |= OperationEquivalence::IgnoreLocations;
+    if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
+                                             mapValue, mapValue, flags))
+      llvm::outs() << " compares equals.\n";
+    else
+      llvm::outs() << " compares NOT equals!\n";
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestOperationEqualPass() {
+  PassRegistration<TestOperationEqualPass>();
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index c5be1daec8304..efbd9ed883b1d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@ void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
 void registerTestMatchers();
+void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
@@ -123,6 +124,7 @@ void registerTestPasses() {
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();
   registerTestMatchers();
+  registerTestOperationEqualPass();
   registerTestPrintDefUsePass();
   registerTestPrintNestingPass();
   registerTestReducer();


        


More information about the Mlir-commits mailing list