[Mlir-commits] [mlir] [mlir] Handle simple commutative cases in CSE. (PR #75274)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 12 18:39:41 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

This change tries to stay simple while still handling obvious CSE opportunities. For more complicated cases, the expectation is still that the sorting pass runs beforehand. Although simple, this case did come up in a real deployment, where handling it had a large end-to-end impact. The approach can, of course, be refined further.

---
Full diff: https://github.com/llvm/llvm-project/pull/75274.diff


2 Files Affected:

- (modified) mlir/include/mlir/IR/OperationSupport.h (+6-2) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+71-14) 


``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564..ba66dffeeb8e9 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1271,7 +1271,9 @@ struct OperationEquivalence {
   isEquivalentTo(Operation *lhs, Operation *rhs,
                  function_ref<LogicalResult(Value, Value)> checkEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
-                 Flags flags = Flags::None);
+                 Flags flags = Flags::None,
+                 function_ref<LogicalResult(ValueRange, ValueRange)>
+                     checkCommutativeEquivalent = nullptr);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1284,9 @@ struct OperationEquivalence {
       Region *lhs, Region *rhs,
       function_ref<LogicalResult(Value, Value)> checkEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
-      OperationEquivalence::Flags flags);
+      OperationEquivalence::Flags flags,
+      function_ref<LogicalResult(ValueRange, ValueRange)>
+          checkCommutativeEquivalent = nullptr);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108..630a3bc5016ff 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+    // If commutative, don't hash the operands as hash is not order independent
+    // and even if it were would not be sufficient for CSE usage.
+    // FIXME: This has the effect of resulting in more hash collisions
+    // for the sake of CSE, this could be improved.
+    hash = llvm::hash_combine(hash, op->getNumOperands());
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
 
   //   - Results
   for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
     Region *lhs, Region *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
-    OperationEquivalence::Flags flags) {
+    OperationEquivalence::Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -751,6 +761,36 @@ struct ValueEquivalenceCache {
     return success(lhsValue == rhsValue ||
                    equivalentValues.lookup(lhsValue) == rhsValue);
   }
+  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+                                           ValueRange rhsRange) {
+    // Operand ranges of different sizes can never be equivalent.
+    if (lhsRange.size() != rhsRange.size())
+      return failure();
+
+    // Fast path: succeed if every operand pair is equivalent in order.
+    auto lhsIt = lhsRange.begin();
+    auto rhsIt = rhsRange.begin();
+    for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+      if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+        break;
+    }
+    if (lhsIt == lhsRange.end())
+      return success();
+
+    // Otherwise, from the first mismatch on, check for a plain permutation.
+    // Note: the remainder is compared by Value identity (pointer order), not
+    // via checkEquivalent, so only simple cases are handled, and cheaply.
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+    auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+    return success(lhsSorted == rhsSorted);
+  }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});
     // Make sure that the value was not already marked equivalent to some other
@@ -773,13 +813,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+    function_ref<void(Value, Value)> markEquivalent, Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   if (lhs == rhs)
     return true;
 
@@ -798,15 +843,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
-    Value curArg = std::get<0>(operandPair);
-    Value otherArg = std::get<1>(operandPair);
-    if (curArg == otherArg)
-      continue;
-    if (curArg.getType() != otherArg.getType())
-      return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+  if (checkCommutativeEquivalent &&
+      lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    auto lhsRange = lhs->getOperands();
+    auto rhsRange = rhs->getOperands();
+    if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
       return false;
+  } else {
+    // Check pair wise for equivalence.
+    for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+      Value curArg = std::get<0>(operandPair);
+      Value otherArg = std::get<1>(operandPair);
+      if (curArg == otherArg)
+        continue;
+      if (curArg.getType() != otherArg.getType())
+        return false;
+      if (failed(checkEquivalent(curArg, otherArg)))
+        return false;
+    }
   }
 
   // 3. Compare result types and mark results as equivalent.
@@ -841,7 +895,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/75274


More information about the Mlir-commits mailing list