[Mlir-commits] [mlir] [mlir][CSE] Introduce DialectInterface for CSE (PR #73520)

Ivan Butygin llvmlistbot at llvm.org
Mon Nov 27 06:21:52 PST 2023


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/73520

Based on https://reviews.llvm.org/D154512 by @uenoku

Depends on https://github.com/llvm/llvm-project/pull/73455; please review only the second commit.

Changed the interface return types to `std::optional` so that an interface can override equivalence for some ops and rely on the default implementation for others.
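
A minimal, self-contained sketch of that fallback pattern, written without MLIR so it compiles on its own. `FakeOp`, `defaultHash`, `customHash`, and `hashWithFallback` are hypothetical names used only for illustration; the real hooks operate on `Operation *`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>

// Hypothetical stand-in for an operation; not the real mlir::Operation.
struct FakeOp {
  std::string name;
  std::string attr;
};

// Default hash: takes both the name and the attribute into account.
unsigned defaultHash(const FakeOp &op) {
  return static_cast<unsigned>(
      std::hash<std::string>{}(op.name + "|" + op.attr));
}

// Custom hook: only handles "test.custom" ops and ignores their attribute.
// Returning std::nullopt means "no opinion, use the default".
std::optional<unsigned> customHash(const FakeOp &op) {
  if (op.name != "test.custom")
    return std::nullopt;
  return static_cast<unsigned>(std::hash<std::string>{}(op.name));
}

// Caller side: try the hook first, fall back to the default otherwise.
unsigned hashWithFallback(const FakeOp &op) {
  if (std::optional<unsigned> hash = customHash(op))
    return *hash;
  return defaultHash(op);
}
```

With this shape, a hook only needs to handle the ops it cares about; every other op flows through the default path unchanged.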

Original description:

This patch implements CSE dialect interfaces proposed in https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831/9.

CSEDialectInterface has three methods. `getHashValue` and `isEqual` are passed to the DenseMapInfo used by CSE, which gives users a way to customize CSE behavior. `mergeOperations` is a callback invoked when an operation is eliminated, so that information can be propagated from the eliminated operation to the one that is kept.
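
A toy sketch of how the three hooks cooperate, using `std::unordered_set` in place of the DenseMap that CSE uses. Everything here (`ToyOp`, `ToyHash`, `ToyEqual`, `toyCSE`) is hypothetical and only loosely mirrors the test in this patch, where a non-essential attribute is ignored for equivalence and the smallest value is kept on merge.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an operation with one non-essential attribute.
struct ToyOp {
  std::string name;
  std::string nonEssential;
};

// Plays the role of getHashValue: the non-essential attribute is excluded.
struct ToyHash {
  std::size_t operator()(const ToyOp *op) const {
    return std::hash<std::string>{}(op->name);
  }
};

// Plays the role of isEqual: ops are equal if they differ only in the
// non-essential attribute.
struct ToyEqual {
  bool operator()(const ToyOp *lhs, const ToyOp *rhs) const {
    return lhs->name == rhs->name;
  }
};

// Plays the role of mergeOperations: keep the smaller attribute value on the
// surviving op. This must not change the survivor's hash, and it does not,
// because ToyHash never looks at `nonEssential`.
void mergeOperations(ToyOp &existing, const ToyOp &toBeDeleted) {
  if (toBeDeleted.nonEssential < existing.nonEssential)
    existing.nonEssential = toBeDeleted.nonEssential;
}

// Returns the ops that survive deduplication, in program order.
std::vector<ToyOp *> toyCSE(std::vector<ToyOp> &ops) {
  std::unordered_set<ToyOp *, ToyHash, ToyEqual> known;
  std::vector<ToyOp *> survivors;
  for (ToyOp &op : ops) {
    auto [it, inserted] = known.insert(&op);
    if (inserted) {
      survivors.push_back(&op);
      continue;
    }
    // `op` duplicates an earlier op: merge its information, then drop it.
    mergeOperations(**it, op);
  }
  return survivors;
}
```

The constraint that merging must leave the hash untouched is what keeps the surviving entry findable in the map after it has been mutated.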

>From 9ff0f41cddef79d2cb11ea84496bd9e5c2e3c5e7 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 26 Nov 2023 19:46:56 +0100
Subject: [PATCH 1/2] [mlir][OperationEquivalence] Add an extra callback to
 hook operation equivalence check

Rebase of https://reviews.llvm.org/D155577 by @uenoku

The only changes are merge-conflict resolutions around the `simpleOpEquivalence` function and the `CFGToSCF.cpp` change.

Original description:

This patch adds an extra callback, `checkOpStructureEquivalent`, to the method `OperationEquivalence::isEquivalentTo` so that users can customize equivalence with respect to the structural properties of operations.
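
A rough, MLIR-free sketch of the idea: the structural comparison becomes a parameter rather than being hard-coded. `MiniOp`, `simpleStructure`, `ignoreAttrStructure`, and `isEquivalent` are hypothetical names and do not match the real signatures, which take `Operation *` and return `LogicalResult`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation.
struct MiniOp {
  std::string name;
  std::string attr;
  std::vector<int> operands;
};

using StructureCheck = std::function<bool(const MiniOp &, const MiniOp &)>;

// Default structural check: name, attribute, and operand count must match.
bool simpleStructure(const MiniOp &lhs, const MiniOp &rhs) {
  return lhs.name == rhs.name && lhs.attr == rhs.attr &&
         lhs.operands.size() == rhs.operands.size();
}

// Custom structural check: same as above, but the attribute is ignored.
bool ignoreAttrStructure(const MiniOp &lhs, const MiniOp &rhs) {
  return lhs.name == rhs.name &&
         lhs.operands.size() == rhs.operands.size();
}

// Equivalence with a pluggable structural check, followed by the operand
// comparison, which is the same for every caller.
bool isEquivalent(const MiniOp &lhs, const MiniOp &rhs,
                  const StructureCheck &checkStructure) {
  if (!checkStructure(lhs, rhs))
    return false;
  return lhs.operands == rhs.operands;
}
```

Callers that want the previous behavior pass the default check; callers with relaxed rules supply their own and still reuse the rest of the comparison.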

I'm not sure that "structural properties" is the best term to describe the operation information used here. I think plain "properties" would be better, but I added "structural" to clearly distinguish it from operation properties.

This patch is preparation for providing a dialect interface for CSE: https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831
---
 mlir/include/mlir/IR/OperationSupport.h   | 48 +++++++++++---
 mlir/lib/IR/OperationSupport.cpp          | 76 +++++++++++++++--------
 mlir/lib/Transforms/Utils/CFGToSCF.cpp    |  1 +
 mlir/lib/Transforms/Utils/RegionUtils.cpp |  3 +-
 4 files changed, 91 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b4..d6c5d50efb7adc1 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1236,16 +1236,31 @@ struct OperationEquivalence {
   };
 
   /// Compute a hash for the given operation.
-  /// The `hashOperands` and `hashResults` callbacks are expected to return a
-  /// unique hash_code for a given Value.
+  /// The `hashOp` is a callback to compute a hash for structural properties of
+  /// `op` such as op name, result types and attributes. The `hashOperands` and
+  /// `hashResults` callbacks are expected to return a unique hash_code for a
+  /// given Value.
   static llvm::hash_code computeHash(
-      Operation *op,
+      Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
       function_ref<llvm::hash_code(Value)> hashOperands =
           [](Value v) { return hash_value(v); },
       function_ref<llvm::hash_code(Value)> hashResults =
           [](Value v) { return hash_value(v); },
       Flags flags = Flags::None);
 
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None) {
+    return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+  }
+
+  /// Helper that can be used with `computeHash` above to combine hashes for
+  /// basic structural properties.
+  static llvm::hash_code simpleOpHash(Operation *op);
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1255,38 +1270,51 @@ struct OperationEquivalence {
 
   /// Compare two operations (including their regions) and return if they are
   /// equivalent.
-  ///
-  /// * `checkEquivalent` is a callback to check if two values are equivalent.
+  /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+  /// two operations are equivalent.
+  /// * `checkValueEquivalent` is a callback to check if two values are
+  /// equivalent.
   ///   For two operations to be equivalent, their operands must be the same SSA
   ///   value or this callback must return `success`.
   /// * `markEquivalent` is a callback to inform the caller that the analysis
   ///   determined that two values are equivalent.
   ///
   /// Note: Additional information regarding value equivalence can be injected
-  /// into the analysis via `checkEquivalent`. Typically, callers may want
-  /// values that were determined to be equivalent as per `markEquivalent` to be
-  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// into the analysis via `checkOpStructureEquivalent` and
+  /// `checkValueEquivalent`. Typically, callers may want values that were
+  /// determined to be equivalent as per `markEquivalent` to be reflected in
+  /// `checkValueEquivalent`, unless `exactValueMatch` or a different
   /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 Flags flags);
 
   /// Compare two regions (including their subregions) and return if they are
   /// equivalent. See also `isEquivalentTo` for details.
   static bool isRegionEquivalentTo(
       Region *lhs, Region *rhs,
-      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<LogicalResult(Operation *, Operation *)>
+          checkOpStructureEquivalent,
+      function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
       OperationEquivalence::Flags flags);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
                                    OperationEquivalence::Flags flags);
+  static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
 
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d8..1d2189e530cdf60 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -668,15 +668,11 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
 //===----------------------------------------------------------------------===//
 
 llvm::hash_code OperationEquivalence::computeHash(
-    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+    function_ref<llvm::hash_code(Value)> hashOperands,
     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
-  // Hash operations based upon their:
-  //   - Operation Name
-  //   - Attributes
-  //   - Result Types
-  llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
-                         op->getResultTypes(), op->hashProperties());
+  // Hash operations based upon their structural properties using `hashOp`.
+  llvm::hash_code hash = hashOp(op);
 
   //   - Location if required
   if (!(flags & Flags::IgnoreLocations))
@@ -694,7 +690,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
     Region *lhs, Region *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
     OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
@@ -724,8 +722,9 @@ llvm::hash_code OperationEquivalence::computeHash(
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+      if (!OperationEquivalence::isEquivalentTo(
+              &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+              markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -766,7 +765,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
                                            OperationEquivalence::Flags flags) {
   ValueEquivalenceCache cache;
   return isRegionEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, simpleOpEquivalence,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
@@ -776,24 +775,39 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       flags);
 }
 
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+  return llvm::hash_combine(op->getName(), op->getResultTypes(),
+                            op->hashProperties(),
+                            op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+  return LogicalResult::success(
+      lhs->getName() == rhs->getName() &&
+      lhs->getDiscardableAttrDictionary() ==
+          rhs->getDiscardableAttrDictionary() &&
+      lhs->getNumRegions() == rhs->getNumRegions() &&
+      lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+      lhs->getNumOperands() == rhs->getNumOperands() &&
+      lhs->getNumResults() == rhs->getNumResults() &&
+      lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+                                         rhs->getPropertiesStorage()));
+}
+
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // 1. Compare the operation properties.
-  if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
-      lhs->getNumRegions() != rhs->getNumRegions() ||
-      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
-      lhs->getNumOperands() != rhs->getNumOperands() ||
-      lhs->getNumResults() != rhs->getNumResults() ||
-      !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
-                                        rhs->getPropertiesStorage()))
+  // 1. Compare the operation structural properties.
+  if (failed(checkOpStructureEquivalent(lhs, rhs)))
     return false;
+
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
@@ -805,7 +819,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       continue;
     if (curArg.getType() != otherArg.getType())
       return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+    if (failed(checkValueEquivalent(curArg, otherArg)))
       return false;
   }
 
@@ -822,7 +836,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
+                              &std::get<1>(regionPair),
+                              checkOpStructureEquivalent, checkValueEquivalent,
                               markEquivalent, flags))
       return false;
 
@@ -832,9 +847,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 /*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
                                                      Operation *rhs,
                                                      Flags flags) {
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, simpleOpEquivalence,
+                                              flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    Flags flags) {
   ValueEquivalenceCache cache;
   return OperationEquivalence::isEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, checkOpStructureEquivalent,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index f2998b4047e201e..9982015ba3de779 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -411,6 +411,7 @@ struct ReturnLikeOpEquivalence : public llvm::DenseMapInfo<Operation *> {
       return false;
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhs), const_cast<Operation *>(rhs),
+        OperationEquivalence::simpleOpEquivalence,
         OperationEquivalence::ignoreValueEquivalence, nullptr,
         OperationEquivalence::IgnoreLocations);
   }
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 1f2344677e6515c..e4e14baaeab7593 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
             /*markEquivalent=*/nullptr,
             OperationEquivalence::Flags::IgnoreLocations))
       return failure();

>From 502d81f6c549a11622d661011e0403a56f8c6177 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Mon, 27 Nov 2023 14:42:06 +0100
Subject: [PATCH 2/2] [mlir][CSE] Introduce DialectInterface for CSE

Based on https://reviews.llvm.org/D154512 by @uenoku

Depends on https://github.com/llvm/llvm-project/pull/73455

Changed the interface return types to `std::optional` so that an interface can override equivalence for some ops and rely on the default implementation for others.

Original description:

This patch implements CSE dialect interfaces proposed in https://discourse.llvm.org/t/rfc-dialectinterface-for-cse/71831/9.

CSEDialectInterface has three methods. `getHashValue` and `isEqual` are passed to the DenseMapInfo used by CSE, which gives users a way to customize CSE behavior. `mergeOperations` is a callback invoked when an operation is eliminated, so that information can be propagated from the eliminated operation to the one that is kept.
---
 mlir/include/mlir/Transforms/CSE.h            | 37 ++++++++
 mlir/lib/Transforms/CSE.cpp                   | 34 ++++++-
 mlir/test/Transforms/cse.mlir                 | 11 +++
 .../Dialect/Test/TestDialectInterfaces.cpp    | 88 ++++++++++++++++++-
 4 files changed, 168 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
index 3d01ece0780509d..053e60b487b6d2c 100644
--- a/mlir/include/mlir/Transforms/CSE.h
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -13,11 +13,14 @@
 #ifndef MLIR_TRANSFORMS_CSE_H_
 #define MLIR_TRANSFORMS_CSE_H_
 
+#include "mlir/IR/DialectInterface.h"
+
 namespace mlir {
 
 class DominanceInfo;
 class Operation;
 class RewriterBase;
+struct LogicalResult;
 
 /// Eliminate common subexpressions within the given operation. This transform
 /// looks for and deduplicates equivalent operations.
@@ -27,6 +30,40 @@ void eliminateCommonSubExpressions(RewriterBase &rewriter,
                                    DominanceInfo &domInfo, Operation *op,
                                    bool *changed = nullptr);
 
+//===----------------------------------------------------------------------===//
+// CSEInterface
+//===----------------------------------------------------------------------===//
+
+/// This is the interface that allows users to customize CSE.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// These two hooks are called in DenseMapInfo used by CSE.
+
+  /// Returns a hash for the operation.
+  /// CSE will use default implementation if `std::nullopt` is returned.
+  virtual std::optional<unsigned> getHashValue(Operation *op) const = 0;
+
+  /// Returns true if two operations are considered to be equivalent.
+  /// CSE will use default implementation if `std::nullopt` is returned.
+  virtual std::optional<bool> isEqual(Operation *lhs, Operation *rhs) const = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// This hook is called when 'op' is considered to be a common subexpression
+  /// of 'existingOp' and is going to be eliminated. This hook allows users to
+  /// propagate information of 'op' to 'existingOp'. Note that the hash value of
+  /// 'existingOp' must not be changed due to the mutation of 'existingOp'.
+  virtual void mergeOperations(Operation *existingOp, Operation *op) const {}
+};
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_CSE_H_
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3affd88d158de59..771467429c8a8e2 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -35,8 +35,16 @@ using namespace mlir;
 namespace {
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    // Use a custom hook if provided.
+    if (auto *interface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+      std::optional<unsigned> val = interface->getHashValue(op);
+      if (val)
+        return *val;
+    }
+
     return OperationEquivalence::computeHash(
-        const_cast<Operation *>(opC),
+        op,
         /*hashOperands=*/OperationEquivalence::directHashValue,
         /*hashResults=*/OperationEquivalence::ignoreHashValue,
         OperationEquivalence::IgnoreLocations);
@@ -49,6 +57,18 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    if (lhs->getDialect() != rhs->getDialect())
+      return false;
+
+    if (auto *interface = dyn_cast<DialectCSEInterface>(lhs->getDialect())) {
+      std::optional<bool> val = interface->isEqual(lhs, rhs);
+      assert(val == interface->isEqual(rhs, lhs) &&
+             "DialectCSEInterface::isEqual must be symmetrical");
+      if (val)
+        return *val;
+    }
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
         OperationEquivalence::IgnoreLocations);
@@ -131,6 +151,18 @@ class CSEDriver {
 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                                      Operation *existing,
                                      bool hasSSADominance) {
+  // Invoke a callback provided by CSE interface.
+  if (auto *cseInterface = dyn_cast<DialectCSEInterface>(op->getDialect())) {
+#ifndef NDEBUG
+    auto hashPrev = cseInterface->getHashValue(existing);
+#endif
+    cseInterface->mergeOperations(existing, op);
+#ifndef NDEBUG
+    assert(hashPrev == cseInterface->getHashValue(existing) &&
+           "hash values must not be modified by `mergeOperations`");
+#endif
+  }
+
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8a..1e348f2d480c9eb 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -520,3 +520,14 @@ func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
   %2 = "test.op_with_memread"() : () -> (i32)
   return %0, %2, %1 : i32, i32, i32
 }
+
+func.func @cse_test_dialect_interface() -> (i32, i32, i32) {
+  %0 = "test.always_speculatable_op"() {test.non_essential = "b"} : () -> i32
+  %1 = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+  %2 = "test.always_speculatable_op"() {test.non_essential = "c"} : () -> i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_test_dialect_interface
+// CHECK-NEXT: %[[result:.+]] = "test.always_speculatable_op"() {test.non_essential = "a"} : () -> i32
+// CHECK-NEXT: return %[[result]], %[[result]], %[[result]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 80ddcdc8ea69f7c..6f5d86f3b9bc2ca 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -9,6 +9,7 @@
 #include "TestDialect.h"
 #include "mlir/Interfaces/FoldInterfaces.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -385,6 +386,90 @@ struct TestReductionPatternInterface : public DialectReductionPatternInterface {
   }
 };
 
+/// This class defines the interface for customizing CSE.
+struct TestCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  StringRef nonEssentialAttrName = "test.non_essential";
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Return a hash that excludes 'test.non_essential' attributes.
+  std::optional<unsigned> getHashValue(Operation *op) const override {
+    auto attr = op->getDiscardableAttrDictionary();
+    if (!attr.contains(nonEssentialAttrName))
+      return std::nullopt;
+
+    auto hashOp = [&](Operation *op) {
+      auto hash = llvm::hash_combine(op->getName(), op->getResultTypes(),
+                                     op->hashProperties());
+      NamedAttrList attributes(attr);
+      attributes.erase(nonEssentialAttrName);
+      return llvm::hash_combine(hash,
+                                attributes.getDictionary(op->getContext()));
+    };
+
+    return OperationEquivalence::computeHash(
+        op,
+        /*hashOp=*/hashOp,
+        /*hashOperands=*/OperationEquivalence::directHashValue,
+        /*hashResults=*/OperationEquivalence::ignoreHashValue,
+        OperationEquivalence::IgnoreLocations);
+  }
+
+  /// Return true if operations are same except for 'test.non_essential'
+  /// attributes.
+  std::optional<bool> isEqual(Operation *lhs, Operation *rhs) const override {
+    if (!lhs->getDiscardableAttrDictionary().contains(nonEssentialAttrName) ||
+        !rhs->getDiscardableAttrDictionary().contains(nonEssentialAttrName))
+      return std::nullopt;
+
+    auto checkOp = [&](Operation *lhs, Operation *rhs) -> LogicalResult {
+      bool result =
+          lhs->getName() == rhs->getName() &&
+          lhs->getNumRegions() == rhs->getNumRegions() &&
+          lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+          lhs->getNumOperands() == rhs->getNumOperands() &&
+          lhs->getNumResults() == rhs->getNumResults() &&
+          lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+                                             rhs->getPropertiesStorage());
+      if (!result)
+        return failure();
+      auto lhsAttr = lhs->getDiscardableAttrs();
+      auto rhsAttr = rhs->getDiscardableAttrs();
+      NamedAttrList lhsFiltered(lhsAttr);
+      lhsFiltered.erase(nonEssentialAttrName);
+      NamedAttrList rhsFiltered(rhsAttr);
+      rhsFiltered.erase(nonEssentialAttrName);
+      return LogicalResult::success(lhsFiltered == rhsFiltered);
+    };
+
+    return OperationEquivalence::isEquivalentTo(
+        lhs, rhs, checkOp, OperationEquivalence::IgnoreLocations);
+  };
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Propagate 'test.non_essential' to an existing op. Use a smaller value if
+  /// both have the attribute.
+  void mergeOperations(Operation *existingOp,
+                       Operation *opsToBeDeleted) const override {
+    if (auto rhs =
+            opsToBeDeleted->getAttrOfType<StringAttr>(nonEssentialAttrName))
+      if (auto lhs =
+              existingOp->getAttrOfType<StringAttr>(nonEssentialAttrName)) {
+        if (rhs.getValue() < lhs.getValue()) {
+          existingOp->setAttr(nonEssentialAttrName, rhs);
+        }
+      } else {
+        existingOp->setAttr(nonEssentialAttrName, rhs);
+      }
+  }
+};
 } // namespace
 
 void TestDialect::registerInterfaces() {
@@ -392,5 +477,6 @@ void TestDialect::registerInterfaces() {
   addInterface<TestOpAsmInterface>(blobInterface);
 
   addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+                TestReductionPatternInterface, TestCSEInterface,
+                TestBytecodeDialectInterface>();
 }


