[Mlir-commits] [mlir] 8a583dd - [MLIR] Add replaceUsesWithIf on Operation
Uday Bondhugula
llvmlistbot at llvm.org
Mon Feb 20 20:45:31 PST 2023
Author: Uday Bondhugula
Date: 2023-02-21T10:10:22+05:30
New Revision: 8a583dd22012294e253f735bf274628791fecb33
URL: https://github.com/llvm/llvm-project/commit/8a583dd22012294e253f735bf274628791fecb33
DIFF: https://github.com/llvm/llvm-project/commit/8a583dd22012294e253f735bf274628791fecb33.diff
LOG: [MLIR] Add replaceUsesWithIf on Operation
Add replaceUsesWithIf on Operation along the lines of
Value::replaceUsesWithIf. This had been missing on Operation and is
convenient for conditionally replacing the results of multi-result operations.
Reviewed By: lattner
Differential Revision: https://reviews.llvm.org/D144348
Added:
Modified:
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/ValueRange.h
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index e7ef01edef313..ac6bdfc265406 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -257,6 +257,15 @@ class alignas(8) Operation final
getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
}
+ /// Replace uses of results of this operation with the provided `values` if
+ /// the given callback returns true.
+ template <typename ValuesT>
+ void replaceUsesWithIf(ValuesT &&values,
+ function_ref<bool(OpOperand &)> shouldReplace) {
+ getResults().replaceUsesWithIf(std::forward<ValuesT>(values),
+ shouldReplace);
+ }
+
/// Destroys this operation and its subclass data.
void destroy();
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 8873260736d16..0f7354bbaef42 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -279,6 +279,26 @@ class ResultRange final
/// Replace all uses of results of this range with results of 'op'.
void replaceAllUsesWith(Operation *op);
+ /// Replace uses of results of this range with the provided 'values' if the
+ /// given callback returns true. The size of `values` must match the size of
+ /// this range.
+ template <typename ValuesT>
+ std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
+ replaceUsesWithIf(ValuesT &&values,
+ function_ref<bool(OpOperand &)> shouldReplace) {
+ assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
+ size() &&
+ "expected 'values' to correspond 1-1 with the number of results");
+
+ for (auto it : llvm::zip(*this, values))
+ std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
+ }
+
+ /// Replace uses of results of this range with results of `op` if the given
+ /// callback returns true.
+ void replaceUsesWithIf(Operation *op,
+ function_ref<bool(OpOperand &)> shouldReplace);
+
//===--------------------------------------------------------------------===//
// Users
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 20ce9b36737fb..a38a12def567c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -589,6 +589,11 @@ void ResultRange::replaceAllUsesWith(Operation *op) {
replaceAllUsesWith(op->getResults());
}
+void ResultRange::replaceUsesWithIf(
+ Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
+ replaceUsesWithIf(op->getResults(), shouldReplace);
+}
+
//===----------------------------------------------------------------------===//
// ValueRange
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 93e5c95cb1999..e98ccccd68d2d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -124,12 +124,9 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
} else {
// When the region does not have SSA dominance, we need to check if we
// have visited a use before replacing any use.
- for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
- std::get<0>(it).replaceUsesWithIf(
- std::get<1>(it), [&](OpOperand &operand) {
- return !knownValues.count(operand.getOwner());
- });
- }
+ op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
+ return !knownValues.count(operand.getOwner());
+ });
// There may be some remaining uses of the operation.
if (op->use_empty())
More information about the Mlir-commits mailing list