[Mlir-commits] [mlir] c864288 - [mlir][transforms] Simplify OperationEquivalence and CSE

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 02:57:20 PST 2023


Author: Matthias Springer
Date: 2023-01-27T11:56:48+01:00
New Revision: c864288dcf169fd1aa663320b7414c0ee72d5c2b

URL: https://github.com/llvm/llvm-project/commit/c864288dcf169fd1aa663320b7414c0ee72d5c2b
DIFF: https://github.com/llvm/llvm-project/commit/c864288dcf169fd1aa663320b7414c0ee72d5c2b.diff

LOG: [mlir][transforms] Simplify OperationEquivalence and CSE

Replace `mapOperands` and `mapResults` with two new callbacks. It was not clear what "mapping" meant and why the equivalence relationship was a property of the Operand/OpResult as opposed to just SSA values.

This revision changes the contract of the two callbacks: `checkEquivalent` compares two values for equivalence. `markEquivalent` informs the caller that the analysis determined that two values are equivalent. This simplifies the API because callers do not have to reason about operands/results, but just SSA values.

`OperationEquivalence::isEquivalentTo` can be used directly in CSE and there is no need for a custom op equivalence analysis.

Differential Revision: https://reviews.llvm.org/D142558

Added: 
    

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Transforms/CSE.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/lib/IR/TestOperationEquals.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 99e63f207e96b..003029d6db94b 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -926,24 +926,36 @@ struct OperationEquivalence {
   /// operands/result mapping.
   static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
 
-  /// Compare two operations and return if they are equivalent.
-  /// `mapOperands` and `mapResults` are optional callbacks that allows the
-  /// caller to check the mapping of SSA value between the lhs and rhs
-  /// operations. It is expected to return success if the mapping is valid and
-  /// failure if it conflicts with a previous mapping.
+  /// Compare two operations (including their regions) and return if they are
+  /// equivalent.
+  ///
+  /// * `checkEquivalent` is a callback to check if two values are equivalent.
+  ///   For two operations to be equivalent, their operands must be the same SSA
+  ///   value or this callback must return `success`.
+  /// * `markEquivalent` is a callback to inform the caller that the analysis
+  ///   determined that two values are equivalent.
+  ///
+  /// Note: Additional information regarding value equivalence can be injected
+  /// into the analysis via `checkEquivalent`. Typically, callers may want
+  /// values that were determined to be equivalent as per `markEquivalent` to be
+  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> mapOperands,
-                 function_ref<LogicalResult(Value, Value)> mapResults,
+                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);
 
-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Compare two operations and return if they are equivalent.
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent even if their operands are not equivalent.
   static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
     return success();
   }
-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent only if their operands are the exact same SSA values.
   static LogicalResult exactValueMatch(Value lhs, Value rhs) {
     return success(lhs == rhs);
   }

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 8f3d37976d2a8..20ce9b36737fb 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -652,8 +652,8 @@ llvm::hash_code OperationEquivalence::computeHash(
 
 static bool
 isRegionEquivalentTo(Region *lhs, Region *rhs,
-                     function_ref<LogicalResult(Value, Value)> mapOperands,
-                     function_ref<LogicalResult(Value, Value)> mapResults,
+                     function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                     function_ref<void(Value, Value)> markEquivalent,
                      OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
@@ -675,15 +675,15 @@ isRegionEquivalentTo(Region *lhs, Region *rhs,
       if (!(flags & OperationEquivalence::IgnoreLocations) &&
           curArg.getLoc() != otherArg.getLoc())
         return false;
-      // Check if this value was already mapped to another value.
-      if (failed(mapOperands(curArg, otherArg)))
-        return false;
+      // Corresponding bbArgs are equivalent.
+      if (markEquivalent)
+        markEquivalent(curArg, otherArg);
     }
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
-                                                mapResults, flags))
+      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
+                                                markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -703,12 +703,12 @@ isRegionEquivalentTo(Region *lhs, Region *rhs,
 
 bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> mapOperands,
-    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
+    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // Compare the operation properties.
+  // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
       lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
@@ -719,6 +719,7 @@ bool OperationEquivalence::isEquivalentTo(
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  // 2. Compare operands.
   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
   SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
   if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
@@ -752,32 +753,58 @@ bool OperationEquivalence::isEquivalentTo(
     rhsOperandStorage = sortValues(rhsOperands);
     rhsOperands = rhsOperandStorage;
   }
-  auto checkValueRangeMapping =
-      [](ValueRange lhs, ValueRange rhs,
-         function_ref<LogicalResult(Value, Value)> mapValues) {
-        for (auto operandPair : llvm::zip(lhs, rhs)) {
-          Value curArg = std::get<0>(operandPair);
-          Value otherArg = std::get<1>(operandPair);
-          if (curArg.getType() != otherArg.getType())
-            return false;
-          if (failed(mapValues(curArg, otherArg)))
-            return false;
-        }
-        return true;
-      };
-  // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
-    return false;
-  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
-    return false;
+
+  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+    Value curArg = std::get<0>(operandPair);
+    Value otherArg = std::get<1>(operandPair);
+    if (curArg == otherArg)
+      continue;
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (failed(checkEquivalent(curArg, otherArg)))
+      return false;
+  }
+
+  // 3. Compare result types and mark results as equivalent.
+  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
+    Value curArg = std::get<0>(resultPair);
+    Value otherArg = std::get<1>(resultPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (markEquivalent)
+      markEquivalent(curArg, otherArg);
+  }
+
+  // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), mapOperands, mapResults,
-                              flags))
+                              &std::get<1>(regionPair), checkEquivalent,
+                              markEquivalent, flags))
       return false;
+
   return true;
 }
 
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
+                                          Flags flags) {
+  // Equivalent values in lhs and rhs.
+  DenseMap<Value, Value> equivalentValues;
+  auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+    return success(lhsValue == rhsValue ||
+                   equivalentValues.lookup(lhsValue) == rhsValue);
+  };
+  auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
+    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+    // Make sure that the value was not already marked equivalent to some other
+    // value.
+    (void)insertion;
+    assert(insertion.first->second == rhsResult &&
+           "inconsistent OperationEquivalence state");
+  };
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
+                                              markEquivalent, flags);
+}
+
 //===----------------------------------------------------------------------===//
 // OperationFingerPrint
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index cd0f90b411f0a..86debe7271fc1 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -47,70 +47,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-
-    // If op has no regions, operation equivalence w.r.t operands alone is
-    // enough.
-    if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
-      return OperationEquivalence::isEquivalentTo(
-          const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-          OperationEquivalence::exactValueMatch,
-          OperationEquivalence::ignoreValueEquivalence,
-          OperationEquivalence::IgnoreLocations);
-    }
-
-    // If lhs or rhs does not have a single region with a single block, they
-    // aren't CSEed for now.
-    if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
-        !llvm::hasSingleElement(lhs->getRegion(0)) ||
-        !llvm::hasSingleElement(rhs->getRegion(0)))
-      return false;
-
-    // Compare the two blocks.
-    Block &lhsBlock = lhs->getRegion(0).front();
-    Block &rhsBlock = rhs->getRegion(0).front();
-
-    // Don't CSE if number of arguments differ.
-    if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
-      return false;
-
-    // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
-    // `rhsBlock`. `Value`s from `lhsBlock` are the key.
-    DenseMap<Value, Value> areEquivalentValues;
-    for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
-                                 rhs->getRegion(0).getArguments())) {
-      areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
-    }
-
-    // Helper function to get the parent operation.
-    auto getParent = [](Value v) -> Operation * {
-      if (auto blockArg = v.dyn_cast<BlockArgument>())
-        return blockArg.getParentBlock()->getParentOp();
-      return v.getDefiningOp()->getParentOp();
-    };
-
-    // Callback to compare if operands of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
-      if (lhsValue == rhsValue)
-        return success();
-      if (areEquivalentValues.lookup(lhsValue) == rhsValue)
-        return success();
-      return failure();
-    };
-
-    // Callback to compare if results of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
-      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
-        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
-        return success(insertion.first->second == rhsResult);
-      }
-      return success();
-    };
-
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-        mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 996588243f565..e5824bd6663b0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -493,7 +493,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
             &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
-            OperationEquivalence::ignoreValueEquivalence,
+            /*markEquivalent=*/nullptr,
             OperationEquivalence::Flags::IgnoreLocations))
       return failure();
 

diff  --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
index 6bfb59ef55533..ef3589636e529 100644
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -28,11 +28,6 @@ struct TestOperationEqualPass
                          << opCount;
       return signalPassFailure();
     }
-    DenseMap<Value, Value> valuesMap;
-    auto mapValue = [&](Value lhs, Value rhs) {
-      auto insertion = valuesMap.insert({lhs, rhs});
-      return success(insertion.first->second == rhs);
-    };
 
     Operation *first = &module.getBody()->front();
     llvm::outs() << first->getName().getStringRef() << " with attr "
@@ -41,7 +36,7 @@ struct TestOperationEqualPass
     if (!first->hasAttr("strict_loc_check"))
       flags |= OperationEquivalence::IgnoreLocations;
     if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
-                                             mapValue, mapValue, flags))
+                                             flags))
       llvm::outs() << " compares equals.\n";
     else
       llvm::outs() << " compares NOT equals!\n";


        


More information about the Mlir-commits mailing list