[Mlir-commits] [mlir] 8a583dd - [MLIR] Add replaceUsesWithIf on Operation

Uday Bondhugula llvmlistbot at llvm.org
Mon Feb 20 20:45:31 PST 2023


Author: Uday Bondhugula
Date: 2023-02-21T10:10:22+05:30
New Revision: 8a583dd22012294e253f735bf274628791fecb33

URL: https://github.com/llvm/llvm-project/commit/8a583dd22012294e253f735bf274628791fecb33
DIFF: https://github.com/llvm/llvm-project/commit/8a583dd22012294e253f735bf274628791fecb33.diff

LOG: [MLIR] Add replaceUsesWithIf on Operation

Add replaceUsesWithIf on Operation, along the lines of
Value::replaceUsesWithIf. This had been missing on Operation and is
convenient for conditionally replacing the results of multi-result operations.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D144348

Added: 
    

Modified: 
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/ValueRange.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Transforms/CSE.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index e7ef01edef313..ac6bdfc265406 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -257,6 +257,15 @@ class alignas(8) Operation final
     getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
   }
 
+  /// Replace uses of this operation's results with `values` (one value per
+  /// result, or an `Operation *`) wherever the given callback returns true.
+  template <typename ValuesT>
+  void replaceUsesWithIf(ValuesT &&values,
+                         function_ref<bool(OpOperand &)> shouldReplace) {
+    getResults().replaceUsesWithIf(std::forward<ValuesT>(values),
+                                   shouldReplace);
+  }
+
   /// Destroys this operation and its subclass data.
   void destroy();
 

diff  --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 8873260736d16..0f7354bbaef42 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -279,6 +279,26 @@ class ResultRange final
   /// Replace all uses of results of this range with results of 'op'.
   void replaceAllUsesWith(Operation *op);
 
+  /// Replace uses of the i-th result of this range with the i-th entry of
+  /// `values` wherever the given callback returns true. `values` must have
+  /// exactly one entry per result in this range.
+  template <typename ValuesT>
+  std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
+  replaceUsesWithIf(ValuesT &&values,
+                    function_ref<bool(OpOperand &)> shouldReplace) {
+    assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
+               size() &&
+           "expected 'values' to correspond 1-1 with the number of results");
+
+    for (auto it : llvm::zip(*this, values))
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
+  }
+
+  /// Replace uses of results of this range with the corresponding results of
+  /// `op` if the given callback returns true; `op` must have size() results.
+  void replaceUsesWithIf(Operation *op,
+                         function_ref<bool(OpOperand &)> shouldReplace);
+
   //===--------------------------------------------------------------------===//
   // Users
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 20ce9b36737fb..a38a12def567c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -589,6 +589,11 @@ void ResultRange::replaceAllUsesWith(Operation *op) {
   replaceAllUsesWith(op->getResults());
 }
 
+void ResultRange::replaceUsesWithIf(
+    Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
+  replaceUsesWithIf(op->getResults(), shouldReplace);
+}
+
 //===----------------------------------------------------------------------===//
 // ValueRange
 

diff  --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 93e5c95cb1999..e98ccccd68d2d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -124,12 +124,9 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
   } else {
     // When the region does not have SSA dominance, we need to check if we
     // have visited a use before replacing any use.
-    for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
-      std::get<0>(it).replaceUsesWithIf(
-          std::get<1>(it), [&](OpOperand &operand) {
-            return !knownValues.count(operand.getOwner());
-          });
-    }
+    op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
+      return !knownValues.count(operand.getOwner());
+    });
 
     // There may be some remaining uses of the operation.
     if (op->use_empty())


        


More information about the Mlir-commits mailing list