[Mlir-commits] [mlir] [mlir][CSE] Introduce DialectInterface for CSE (PR #73520)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 06:22:22 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Based on https://reviews.llvm.org/D154512 by @<!-- -->uenoku

Depends on https://github.com/llvm/llvm-project/pull/73455 , review second commit

Changed the interface return types to `std::optional` so that an interface can override equivalence for some ops and rely on the default implementation for others.

Original description:

This patch implements CSE dialect interfaces proposed in https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831/9.

CSEDialectInterface has three methods. `getHashValue` and `isEqual` are passed to the DenseMapInfo used by CSE, providing a way for users to customize CSE behavior. `mergeOperations` is a callback invoked when an operation is eliminated, so that information can be propagated from the eliminated operation to the existing operation that replaces it.

---

Patch is 23.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73520.diff


8 Files Affected:

- (modified) mlir/include/mlir/IR/OperationSupport.h (+38-10) 
- (modified) mlir/include/mlir/Transforms/CSE.h (+37) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+50-26) 
- (modified) mlir/lib/Transforms/CSE.cpp (+33-1) 
- (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+1) 
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+2-1) 
- (modified) mlir/test/Transforms/cse.mlir (+11) 
- (modified) mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp (+87-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b4..d6c5d50efb7adc1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1236,16 +1236,31 @@ struct OperationEquivalence {
   };
 
   /// Compute a hash for the given operation.
-  /// The `hashOperands` and `hashResults` callbacks are expected to return a
-  /// unique hash_code for a given Value.
+  /// The `hashOp` is a callback to compute a hash for structural properties of
+  /// `op` such as op name, result types and attributes. The `hashOperands` and
+  /// `hashResults` callbacks are expected to return a unique hash_code for a
+  /// given Value.
   static llvm::hash_code computeHash(
-      Operation *op,
+      Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
       function_ref<llvm::hash_code(Value)> hashOperands =
           [](Value v) { return hash_value(v); },
       function_ref<llvm::hash_code(Value)> hashResults =
           [](Value v) { return hash_value(v); },
       Flags flags = Flags::None);
 
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None) {
+    return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+  }
+
+  /// Helper that can be used with `computeHash` above to combine hashes for
+  /// basic structural properties.
+  static llvm::hash_code simpleOpHash(Operation *op);
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1255,38 +1270,51 @@ struct OperationEquivalence {
 
   /// Compare two operations (including their regions) and return if they are
   /// equivalent.
-  ///
-  /// * `checkEquivalent` is a callback to check if two values are equivalent.
+  /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+  /// two operations are equivalent.
+  /// * `checkValueEquivalent` is a callback to check if two values are
+  /// equivalent.
   ///   For two operations to be equivalent, their operands must be the same SSA
   ///   value or this callback must return `success`.
   /// * `markEquivalent` is a callback to inform the caller that the analysis
   ///   determined that two values are equivalent.
   ///
   /// Note: Additional information regarding value equivalence can be injected
-  /// into the analysis via `checkEquivalent`. Typically, callers may want
-  /// values that were determined to be equivalent as per `markEquivalent` to be
-  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// into the analysis via `checkOpStructureEquivalent` and
+  /// `checkValueEquivalent`. Typically, callers may want values that were
+  /// determined to be equivalent as per `markEquivalent` to be reflected in
+  /// `checkValueEquivalent`, unless `exactValueMatch` or a different
   /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 Flags flags);
 
   /// Compare two regions (including their subregions) and return if they are
   /// equivalent. See also `isEquivalentTo` for details.
   static bool isRegionEquivalentTo(
       Region *lhs, Region *rhs,
-      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<LogicalResult(Operation *, Operation *)>
+          checkOpStructureEquivalent,
+      function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
       OperationEquivalence::Flags flags);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
                                    OperationEquivalence::Flags flags);
+  static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
 
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509d..053e60b487b6d2c 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,11 +13,14 @@
 #ifndef MLIR_TRANSFORMS_CSE_H_
 #define MLIR_TRANSFORMS_CSE_H_
 
+#include "mlir/IR/DialectInterface.h"
+
 namespace mlir {
 
 class DominanceInfo;
 class Operation;
 class RewriterBase;
+struct LogicalResult;
 
 /// Eliminate common subexpressions within the given operation. This transform
 /// looks for and deduplicates equivalent operations.
@@ -27,6 +30,40 @@ void eliminateCommonSubExpressions(RewriterBase &rewriter,
                                    DominanceInfo &domInfo, Operation *op,
                                    bool *changed = nullptr);
 
+//===----------------------------------------------------------------------===//
+// CSEInterface
+//===----------------------------------------------------------------------===//
+
+/// This is the interface that allows users to customize CSE.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// These two hooks are called in DenseMapInfo used by CSE.
+
+  /// Returns a hash for the operation.
+  /// CSE will use default implementation if `std::nullopt` is returned.
+  virtual std::optional<unsigned> getHashValue(Operation *op) const = 0;
+
+  /// Returns true if two operations are considered to be equivalent.
+  /// CSE will use default implementation if `std::nullopt` is returned.
+  virtual std::optional<bool> isEqual(Operation *lhs, Operation *rhs) const = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// This hook is called when 'op' is considered to be a common subexpression
+  /// of 'existingOp' and is going to be eliminated. This hook allows users to
+  /// propagate information of 'op' to 'existingOp'. Note that the hash value of
+  /// 'existingOp' must not be changed due the mutation of 'existingOp'.
+  virtual void mergeOperations(Operation *existingOp, Operation *op) const {}
+};
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_CSE_H_
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d8..1d2189e530cdf60 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -668,15 +668,11 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
 //===----------------------------------------------------------------------===//
 
 llvm::hash_code OperationEquivalence::computeHash(
-    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+    function_ref<llvm::hash_code(Value)> hashOperands,
     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
-  // Hash operations based upon their:
-  //   - Operation Name
-  //   - Attributes
-  //   - Result Types
-  llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
-                         op->getResultTypes(), op->hashProperties());
+  // Hash operations based upon their structural properties using `hashOp`.
+  llvm::hash_code hash = hashOp(op);
 
   //   - Location if required
   if (!(flags & Flags::IgnoreLocations))
@@ -694,7 +690,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
     Region *lhs, Region *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
     OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
@@ -724,8 +722,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+      if (!OperationEquivalence::isEquivalentTo(
+              &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+              markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -766,7 +765,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
                                            OperationEquivalence::Flags flags) {
   ValueEquivalenceCache cache;
   return isRegionEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, simpleOpEquivalence,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
@@ -776,24 +775,39 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       flags);
 }
 
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+  return llvm::hash_combine(op->getName(), op->getResultTypes(),
+                            op->hashProperties(),
+                            op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+  return LogicalResult::success(
+      lhs->getName() == rhs->getName() &&
+      lhs->getDiscardableAttrDictionary() ==
+          rhs->getDiscardableAttrDictionary() &&
+      lhs->getNumRegions() == rhs->getNumRegions() &&
+      lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+      lhs->getNumOperands() == rhs->getNumOperands() &&
+      lhs->getNumResults() == rhs->getNumResults() &&
+      lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+                                         rhs->getPropertiesStorage()));
+}
+
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // 1. Compare the operation properties.
-  if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
-      lhs->getNumRegions() != rhs->getNumRegions() ||
-      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
-      lhs->getNumOperands() != rhs->getNumOperands() ||
-      lhs->getNumResults() != rhs->getNumResults() ||
-      !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
-                                        rhs->getPropertiesStorage()))
+  // 1. Compare the operation structural properties.
+  if (failed(checkOpStructureEquivalent(lhs, rhs)))
     return false;
+
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
@@ -805,7 +819,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       continue;
     if (curArg.getType() != otherArg.getType())
       return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+    if (failed(checkValueEquivalent(curArg, otherArg)))
       return false;
   }
 
@@ -822,7 +836,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
+                              &std::get<1>(regionPair),
+                              checkOpStructureEquivalent, checkValueEquivalent,
                               markEquivalent, flags))
       return false;
 
@@ -832,9 +847,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
                                                      Operation *rhs,
                                                      Flags flags) {
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, simpleOpEquivalence,
+                                              flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    Flags flags) {
   ValueEquivalenceCache cache;
   return OperationEquivalence::isEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, checkOpStructureEquivalent,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de59..771467429c8a8e2 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -35,8 +35,16 @@ using namespace mlir;
 namespace {
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    // Use a custom hook if provided.
+    if (auto *interface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+      std::optional<unsigned> val = interface->getHashValue(op);
+      if (val)
+        return *val;
+    }
+
     return OperationEquivalence::computeHash(
-        const_cast<Operation *>(opC),
+        op,
         /*hashOperands=*/OperationEquivalence::directHashValue,
         /*hashResults=*/OperationEquivalence::ignoreHashValue,
         OperationEquivalence::IgnoreLocations);
@@ -49,6 +57,18 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    if (lhs->getDialect() != rhs->getDialect())
+      return false;
+
+    if (auto *interface = dyn_cast<DialectCSEInterface>(lhs->getDialect())) {
+      std::optional<bool> val = interface->isEqual(lhs, rhs);
+      assert(val == interface->isEqual(rhs, lhs) &&
+             "DialectCSEInterface::isEqual must be symmetrical");
+      if (val)
+        return *val;
+    }
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
         OperationEquivalence::IgnoreLocations);
@@ -131,6 +151,18 @@ class CSEDriver {
 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                      Operation *existing,
                                      bool hasSSADominance) {
+  // Invoke a callback provided by CSE interface.
+  if (auto *cseInterface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+#ifndef NDEBUG
+    auto hashPrev = cseInterface->getHashValue(existing);
+#endif
+    cseInterface->mergeOperations(existing, op);
+#ifndef NDEBUG
+    assert(hashPrev == cseInterface->getHashValue(existing) &&
+           "hash values must not be modified by `mergeOperations`");
+#endif
+  }
+
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index f2998b4047e201e..9982015ba3de779 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -411,6 +411,7 @@ struct ReturnLikeOpEquivalence : public llvm::DenseMapInfo<Operation *> {
       return false;
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhs), const_cast<Operation *>(rhs),
+        OperationEquivalence::simpleOpEquivalence,
         OperationEquivalence::ignoreValueEquivalence, nullptr,
         OperationEquivalence::IgnoreLocations);
   }
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 1f2344677e6515c..e4e14baaeab7593 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
             /*markEquivalent=*/nullptr,
             OperationEquivalence::Flags::IgnoreLocations))
       return failure();
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8a..1e348f2d480c9eb 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -520,3 +520,14 @@ func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
   %2 = "test.op_with_memread"() : () -> (i32)
   return %0, %2, %1 : i32, i32, i32
 }
+
+func.func @cse_test_dialect_interface() -> (i32, i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.non_essential = "b"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+  %2 = "test.always_speculatable_op"() {test.non_essential = "c"} : () -> i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]], %[[result]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 80ddcdc8ea69f7c..6f5d86f3b9bc2ca 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -9,6 +9,7 @@
 #include "TestDialect.h"
 #include "mlir/Interfaces/FoldInterfaces.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -385,6 +386,90 @@ struct TestReductionPatternInterface : public DialectReductionPatternInterface {
   }
 };
 
+/// This class defines the interface for customizing CSE.
+struct TestCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  StringRef nonEssentialAttrName = "test.non_essential";
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Return a hash that excludes 'test.non_essential' attributes.
+  std::optional<unsigned> getHashValue(Operation *op) const override {
+    auto attr = op->g...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73520


More information about the Mlir-commits mailing list