[flang-commits] [mlir] [flang] [mlir] Handle simple commutative cases in CSE. (PR #75274)

Jacques Pienaar via flang-commits flang-commits at lists.llvm.org
Wed Dec 13 21:11:36 PST 2023


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/75274

>From d3cefdd3d0adbd4e307cb4b71c98b81fe49ce08b Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 12 Dec 2023 18:30:57 -0800
Subject: [PATCH 1/3] [mlir] Handle simple commutative cases in CSE.

Tried to keep this simple while still handling the obvious CSE instances. For more complicated cases, the expectation is still that an operand-sorting pass runs beforehand. Although simple, this case did turn up in a real deployed instance, where it had a large (>10% end-to-end) impact. The approach can of course be refined further.
---
 mlir/include/mlir/IR/OperationSupport.h |  8 ++-
 mlir/lib/IR/OperationSupport.cpp        | 85 +++++++++++++++++++++----
 2 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b..ba66dffeeb8e96 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1271,7 +1271,9 @@ struct OperationEquivalence {
   isEquivalentTo(Operation *lhs, Operation *rhs,
                  function_ref<LogicalResult(Value, Value)> checkEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
-                 Flags flags = Flags::None);
+                 Flags flags = Flags::None,
+                 function_ref<LogicalResult(ValueRange, ValueRange)>
+                     checkCommutativeEquivalent = nullptr);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1284,9 @@ struct OperationEquivalence {
       Region *lhs, Region *rhs,
       function_ref<LogicalResult(Value, Value)> checkEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
-      OperationEquivalence::Flags flags);
+      OperationEquivalence::Flags flags,
+      function_ref<LogicalResult(ValueRange, ValueRange)>
+          checkCommutativeEquivalent = nullptr);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d..630a3bc5016fff 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+    // If commutative, don't hash the operands as hash is not order independent
+    // and even if it were would not be sufficient for CSE usage.
+    // FIXME: This has the effect of resulting in more hash collisions
+    // for the sake of CSE, this could be improved.
+    hash = llvm::hash_combine(hash, op->getNumOperands());
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
 
   //   - Results
   for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
     Region *lhs, Region *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
-    OperationEquivalence::Flags flags) {
+    OperationEquivalence::Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -751,6 +761,36 @@ struct ValueEquivalenceCache {
     return success(lhsValue == rhsValue ||
                    equivalentValues.lookup(lhsValue) == rhsValue);
   }
+  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+                                           ValueRange rhsRange) {
+    // Ranges of different lengths can never be permutations of each other.
+    if (lhsRange.size() != rhsRange.size())
+      return failure();
+
+    // Fast path: compare positionally and stop at the first mismatched pair.
+    auto lhsIt = lhsRange.begin();
+    auto rhsIt = rhsRange.begin();
+    for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+      if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+        break;
+    }
+    if (lhsIt == lhsRange.end())
+      return success();
+
+    // Slow path: accept if the remaining operands are a permutation of each
+    // other. Incomplete but cheap: Values are compared by identity only,
+    // so the equivalentValues mapping used by checkEquivalent is ignored.
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+    auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+    return success(lhsSorted == rhsSorted);
+  }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});
     // Make sure that the value was not already marked equivalent to some other
@@ -773,13 +813,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+    function_ref<void(Value, Value)> markEquivalent, Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   if (lhs == rhs)
     return true;
 
@@ -798,15 +843,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
-    Value curArg = std::get<0>(operandPair);
-    Value otherArg = std::get<1>(operandPair);
-    if (curArg == otherArg)
-      continue;
-    if (curArg.getType() != otherArg.getType())
-      return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+  if (checkCommutativeEquivalent &&
+      lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    auto lhsRange = lhs->getOperands();
+    auto rhsRange = rhs->getOperands();
+    if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
       return false;
+  } else {
+    // Check pair wise for equivalence.
+    for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+      Value curArg = std::get<0>(operandPair);
+      Value otherArg = std::get<1>(operandPair);
+      if (curArg == otherArg)
+        continue;
+      if (curArg.getType() != otherArg.getType())
+        return false;
+      if (failed(checkEquivalent(curArg, otherArg)))
+        return false;
+    }
   }
 
   // 3. Compare result types and mark results as equivalent.
@@ -841,7 +895,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//

>From 6fc2f9d23168503abf5778084b09d1cc614cf3da Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 13 Dec 2023 13:28:44 +0000
Subject: [PATCH 2/3] Format, and enable the now-passing test.

---
 flang/test/Fir/commute.fir       | 3 ---
 mlir/lib/IR/OperationSupport.cpp | 3 ++-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/flang/test/Fir/commute.fir b/flang/test/Fir/commute.fir
index b39fe48145b3f3..be25fd61953deb 100644
--- a/flang/test/Fir/commute.fir
+++ b/flang/test/Fir/commute.fir
@@ -1,7 +1,4 @@
 // RUN: fir-opt %s | tco | FileCheck %s
-//
-// XFAIL:*
-// See: https://github.com/llvm/llvm-project/issues/63784
 
 // CHECK-LABEL: define i32 @f1(i32 %0, i32 %1)
 func.func @f1(%a : i32, %b : i32) -> i32 {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 630a3bc5016fff..14e21dc1d1a87c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,7 +683,8 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
+      op->getNumOperands() > 0) {
     // If commutative, don't hash the operands as hash is not order independent
     // and even if it were would not be sufficient for CSE usage.
     // FIXME: This has the effect of resulting in more hash collisions

>From 1a9bbe230c3dcab7a5394b85edf70c547c2df6b4 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Thu, 14 Dec 2023 05:10:07 +0000
Subject: [PATCH 3/3] Pass the missing commutative comparator to nested isEquivalentTo calls.

---
 mlir/lib/IR/OperationSupport.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 14e21dc1d1a87c..19f52d7ce42eb0 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -736,7 +736,8 @@ llvm::hash_code OperationEquivalence::computeHash(
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
       if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+                                                markEquivalent, flags,
+                                                checkCommutativeEquivalent))
         return false;
       // Check successor mapping.
       for (auto successorsPair :



More information about the flang-commits mailing list