[flang-commits] [flang] ee2deb4 - [mlir] Handle simple commutative cases in CSE.

Jacques Pienaar via flang-commits flang-commits at lists.llvm.org
Thu Dec 14 16:09:13 PST 2023


Author: Jacques Pienaar
Date: 2023-12-14T16:09:05-08:00
New Revision: ee2deb4cf70e9468510750070435f29ca7482d02

URL: https://github.com/llvm/llvm-project/commit/ee2deb4cf70e9468510750070435f29ca7482d02
DIFF: https://github.com/llvm/llvm-project/commit/ee2deb4cf70e9468510750070435f29ca7482d02.diff

LOG: [mlir] Handle simple commutative cases in CSE.

Tried to keep this simple while handling obvious CSE instances. For more
complicated cases, the expectation is still that the sorting pass is run
beforehand. While simple, this case did turn up in a real deployed
instance where it had a large (>10% e2e) impact. This can of course be
refined.

Added: 
    

Modified: 
    flang/test/Fir/commute.fir
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/commute.fir b/flang/test/Fir/commute.fir
index b39fe48145b3f3..be25fd61953deb 100644
--- a/flang/test/Fir/commute.fir
+++ b/flang/test/Fir/commute.fir
@@ -1,7 +1,4 @@
 // RUN: fir-opt %s | tco | FileCheck %s
-//
-// XFAIL:*
-// See: https://github.com/llvm/llvm-project/issues/63784
 
 // CHECK-LABEL: define i32 @f1(i32 %0, i32 %1)
 func.func @f1(%a : i32, %b : i32) -> i32 {

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564b..478504b8f76e7a 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1261,6 +1261,10 @@ struct OperationEquivalence {
   ///   value or this callback must return `success`.
   /// * `markEquivalent` is a callback to inform the caller that the analysis
   ///   determined that two values are equivalent.
+  /// * `checkCommutativeEquivalent` is an optional callback to check for
+  ///   equivalence across two ranges for a commutative operation. If not passed
+  ///   in, then equivalence is checked pairwise. This callback is needed to be
+  ///   able to query the optional equivalence classes.
   ///
   /// Note: Additional information regarding value equivalence can be injected
   /// into the analysis via `checkEquivalent`. Typically, callers may want
@@ -1271,7 +1275,9 @@ struct OperationEquivalence {
   isEquivalentTo(Operation *lhs, Operation *rhs,
                  function_ref<LogicalResult(Value, Value)> checkEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
-                 Flags flags = Flags::None);
+                 Flags flags = Flags::None,
+                 function_ref<LogicalResult(ValueRange, ValueRange)>
+                     checkCommutativeEquivalent = nullptr);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1288,9 @@ struct OperationEquivalence {
       Region *lhs, Region *rhs,
       function_ref<LogicalResult(Value, Value)> checkEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
-      OperationEquivalence::Flags flags);
+      OperationEquivalence::Flags flags,
+      function_ref<LogicalResult(ValueRange, ValueRange)>
+          checkCommutativeEquivalent = nullptr);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108d..eaad6f2891608f 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
+      op->getNumOperands() > 0) {
+    size_t operandHash = hashOperands(op->getOperand(0));
+    for (auto operand : op->getOperands().drop_front())
+      operandHash += hashOperands(operand);
+    hash = llvm::hash_combine(hash, operandHash);
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
 
   //   - Results
   for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
     Region *lhs, Region *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
-    OperationEquivalence::Flags flags) {
+    OperationEquivalence::Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -725,7 +735,8 @@ llvm::hash_code OperationEquivalence::computeHash(
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
       if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+                                                markEquivalent, flags,
+                                                checkCommutativeEquivalent))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -751,6 +762,36 @@ struct ValueEquivalenceCache {
     return success(lhsValue == rhsValue ||
                    equivalentValues.lookup(lhsValue) == rhsValue);
   }
+  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+                                           ValueRange rhsRange) {
+    // Ranges of different sizes can never be equivalent.
+    if (lhsRange.size() != rhsRange.size())
+      return failure();
+
+    // Fast path: consume the leading pairs that are equivalent in order.
+    auto lhsIt = lhsRange.begin();
+    auto rhsIt = rhsRange.begin();
+    for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+      if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+        break;
+    }
+    if (lhsIt == lhsRange.end())
+      return success();
+
+    // The remaining operands must be an exact permutation of each other.
+    // Note: they are compared as raw Values (sorted by opaque pointer), not
+    // via checkEquivalent, so this only covers the simple cases cheaply.
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+    auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+    return success(lhsSorted == rhsSorted);
+  }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});
     // Make sure that the value was not already marked equivalent to some other
@@ -773,13 +814,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+    function_ref<void(Value, Value)> markEquivalent, Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   if (lhs == rhs)
     return true;
 
@@ -798,15 +844,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
-    Value curArg = std::get<0>(operandPair);
-    Value otherArg = std::get<1>(operandPair);
-    if (curArg == otherArg)
-      continue;
-    if (curArg.getType() != otherArg.getType())
-      return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+  if (checkCommutativeEquivalent &&
+      lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    auto lhsRange = lhs->getOperands();
+    auto rhsRange = rhs->getOperands();
+    if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
       return false;
+  } else {
+    // Check pair wise for equivalence.
+    for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+      Value curArg = std::get<0>(operandPair);
+      Value otherArg = std::get<1>(operandPair);
+      if (curArg == otherArg)
+        continue;
+      if (curArg.getType() != otherArg.getType())
+        return false;
+      if (failed(checkEquivalent(curArg, otherArg)))
+        return false;
+    }
   }
 
   // 3. Compare result types and mark results as equivalent.
@@ -841,7 +896,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the flang-commits mailing list