[Mlir-commits] [mlir] df00e46 - [mlir] Move the operation equivalence out of CSE and into OperationSupport

River Riddle llvmlistbot at llvm.org
Wed Apr 29 16:49:44 PDT 2020


Author: River Riddle
Date: 2020-04-29T16:48:15-07:00
New Revision: df00e466daf59358725f977a578e85bae4c52765

URL: https://github.com/llvm/llvm-project/commit/df00e466daf59358725f977a578e85bae4c52765
DIFF: https://github.com/llvm/llvm-project/commit/df00e466daf59358725f977a578e85bae4c52765.diff

LOG: [mlir] Move the operation equivalence out of CSE and into OperationSupport

This provides a general hash and comparison for checking whether two operations are equivalent. This revision also optimizes the handling of result types by taking advantage of how result types are stored on the operation (multiple result types are stored in a uniqued TupleType).

Differential Revision: https://reviews.llvm.org/D79029

Added: 
    

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Transforms/CSE.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1beb8d14151c..2214b5db2f20 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -814,6 +814,20 @@ class ValueRange final
   /// Allow access to `offset_base` and `dereference_iterator`.
   friend RangeBaseT;
 };
+
+//===----------------------------------------------------------------------===//
+// Operation Equivalency
+//===----------------------------------------------------------------------===//
+
+/// This class provides utilities for hashing operations and for checking
+/// whether two operations are equivalent.
+struct OperationEquivalence {
+  /// Compute a hash for `op`, intended to agree with `isEquivalentTo`.
+  static llvm::hash_code computeHash(Operation *op);
+
+  /// Return true if `lhs` and `rhs` are equivalent operations.
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs);
+};
 } // end namespace mlir
 
 namespace llvm {

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 087828e6e519..83b4f0bf176e 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -395,3 +395,78 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
   Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
   return operation->getResult(owner.startIndex + index);
 }
+
+//===----------------------------------------------------------------------===//
+// Operation Equivalency
+//===----------------------------------------------------------------------===//
+
+llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
+  // Hash operations based upon their:
+  //   - Operation Name
+  //   - Attributes
+  llvm::hash_code hash = llvm::hash_combine(
+      op->getName(), op->getMutableAttrDict().getDictionary());
+
+  //   - Result Types
+  ArrayRef<Type> resultTypes = op->getResultTypes();
+  switch (resultTypes.size()) {
+  case 0:
+    // We don't need to add anything to the hash.
+    break;
+  case 1:
+    // Add in the result type.
+    hash = llvm::hash_combine(hash, resultTypes.front());
+    break;
+  default:
+    // Hash the address of the type buffer rather than each type. This relies
+    // on the fact that, when there is more than one result, the result types
+    // are stored in a uniqued TupleType, so equal ranges share one buffer.
+    hash = llvm::hash_combine(hash, resultTypes.data());
+    break;
+  }
+
+  //   - Operands
+  // TODO: Allow commutative operations to have different ordering.
+  return llvm::hash_combine(
+      hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+}
+
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
+  if (lhs == rhs)
+    return true;
+
+  // Compare the operation name.
+  if (lhs->getName() != rhs->getName())
+    return false;
+  // Check operand counts.
+  if (lhs->getNumOperands() != rhs->getNumOperands())
+    return false;
+  // Compare attributes.
+  if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
+    return false;
+  // Compare result types.
+  ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
+  ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
+  if (lhsResultTypes.size() != rhsResultTypes.size())
+    return false;
+  switch (lhsResultTypes.size()) {
+  case 0:
+    break;
+  case 1:
+    // Compare the single result type.
+    if (lhsResultTypes.front() != rhsResultTypes.front())
+      return false;
+    break;
+  default:
+    // Compare the type buffer addresses rather than each type. This relies
+    // on the fact that, when there is more than one result, the result types
+    // are stored in a uniqued TupleType, so equal ranges share one buffer.
+    if (lhsResultTypes.data() != rhsResultTypes.data())
+      return false;
+    break;
+  }
+  // Compare operands; the operand counts were already checked to be equal.
+  // TODO: Allow commutative operations to have different ordering.
+  return std::equal(lhs->operand_begin(), lhs->operand_end(),
+                    rhs->operand_begin());
+}

diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 1a91849e6102..58ffc0991946 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -26,19 +26,9 @@
 using namespace mlir;
 
 namespace {
-// TODO(riverriddle) Handle commutative operations.
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
-    auto *op = const_cast<Operation *>(opC);
-    // Hash the operations based upon their:
-    //   - Operation Name
-    //   - Attributes
-    //   - Result Types
-    //   - Operands
-    return llvm::hash_combine(
-        op->getName(), op->getMutableAttrDict().getDictionary(),
-        op->getResultTypes(),
-        llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+    return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
@@ -48,24 +38,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-
-    // Compare the operation name.
-    if (lhs->getName() != rhs->getName())
-      return false;
-    // Check operand and result type counts.
-    if (lhs->getNumOperands() != rhs->getNumOperands() ||
-        lhs->getNumResults() != rhs->getNumResults())
-      return false;
-    // Compare attributes.
-    if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
-      return false;
-    // Compare operands.
-    if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
-                    rhs->operand_begin()))
-      return false;
-    // Compare result types.
-    return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
-                      rhs->result_type_begin());
+    return OperationEquivalence::isEquivalentTo(const_cast<Operation *>(lhsC),
+                                                const_cast<Operation *>(rhsC));
   }
 };
 } // end anonymous namespace


        


More information about the Mlir-commits mailing list