[Mlir-commits] [mlir] 5ce68f4 - [mlir] Introduce `replaceUsesOfWith` to `RewriterBase`
Guray Ozen
llvmlistbot at llvm.org
Wed Nov 16 08:53:18 PST 2022
Author: Guray Ozen
Date: 2022-11-16T17:53:11+01:00
New Revision: 5ce68f4284c694392238f1c8c5308d08d9a56251
URL: https://github.com/llvm/llvm-project/commit/5ce68f4284c694392238f1c8c5308d08d9a56251
DIFF: https://github.com/llvm/llvm-project/commit/5ce68f4284c694392238f1c8c5308d08d9a56251.diff
LOG: [mlir] Introduce `replaceUsesOfWith` to `RewriterBase`
Finding the uses of a value and replacing them with a new one is a common operation. I have not seen a safe and easy shortcut that does that. This revision attempts to address that by introducing `replaceUsesOfWith` to `RewriterBase`.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D138110
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/IR/PatternMatch.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index e257b67ad9d8e..7b05a2d61e905 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -502,6 +502,11 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
finalizeRootUpdate(root);
}
+ /// Find uses of `from` and replace them with `to`. It also marks every
+ /// modified use and notifies the rewriter that an in-place operation
+ /// modification is about to happen.
+ void replaceAllUsesWith(Value from, Value to);
+
/// Used to notify the rewriter that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index ccac41235b230..ec493bfbd1a5a 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -246,15 +246,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
sourceBlock.getOperations());
// Step 5. RAUW thread indices to thread ops.
- for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
- Value val = bvm.lookup(blockIdx);
- SmallVector<OpOperand *> uses;
- for (OpOperand &use : blockIdx.getUses())
- uses.push_back(&use);
- for (OpOperand *operand : uses) {
- Operation *op = operand->getOwner();
- rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
- }
+ for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+ Value blockIdx = bvm.lookup(loopIndex);
+ rewriter.replaceAllUsesWith(loopIndex, blockIdx);
}
// Step 6. Erase old op.
@@ -492,15 +486,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
sourceBlock.getOperations());
// Step 6. RAUW thread indices to thread ops.
- for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
- Value val = bvm.lookup(threadIdx);
- SmallVector<OpOperand *> uses;
- for (OpOperand &use : threadIdx.getUses())
- uses.push_back(&use);
- for (OpOperand *operand : uses) {
- Operation *op = operand->getOwner();
- rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
- }
+ for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+ Value threadIdx = bvm.lookup(loopIndex);
+ rewriter.replaceAllUsesWith(loopIndex, threadIdx);
}
// Step 7. syncthreads.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index d2de65e7694ba..d3072b506c027 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -309,6 +309,14 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
source->erase();
}
+/// Find uses of `from` and replace them with `to`.
+void RewriterBase::replaceAllUsesWith(Value from, Value to) {
+ for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
+ Operation *op = operand.getOwner();
+ updateRootInPlace(op, [&]() { operand.set(to); });
+ }
+}
+
// Merge the operations of block 'source' before the operation 'op'. Source
// block should not have existing predecessors or successors.
void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
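
The new helper walks `from.getUses()` with `llvm::make_early_inc_range` because `operand.set(to)` unlinks the operand from the use list that is being iterated. The following is a rough, self-contained model of that idea, not the MLIR implementation: `ToyValue`, `ToyOperand`, `ToyRewriter`, `updateInPlace` and `demoReplaceAllUses` are hypothetical stand-ins invented for illustration, and a `std::list` is assumed in place of MLIR's intrusive use list.

```cpp
#include <cstddef>
#include <functional>
#include <list>
#include <string>
#include <vector>

// Hypothetical stand-ins for mlir::Value and mlir::OpOperand (not MLIR types).
struct ToyOperand;

struct ToyValue {
  std::string name;
  std::list<ToyOperand *> uses; // every operand that currently reads this value
};

struct ToyOperand {
  std::string owner; // name of the operation that owns this operand
  ToyValue *value;   // value currently read by this operand (may be null)

  // Re-point the operand; this removes it from the old value's use list.
  void set(ToyValue *newValue) {
    if (value)
      value->uses.remove(this);
    value = newValue;
    if (newValue)
      newValue->uses.push_back(this);
  }
};

// Hypothetical rewriter that records which owners were updated in place.
struct ToyRewriter {
  std::vector<std::string> modifiedOwners;

  // Models "notify, then mutate": log the owner, then run the mutation.
  void updateInPlace(const std::string &owner,
                     const std::function<void()> &mutate) {
    modifiedOwners.push_back(owner);
    mutate();
  }

  void replaceAllUsesWith(ToyValue &from, ToyValue &to) {
    for (auto it = from.uses.begin(); it != from.uses.end();) {
      ToyOperand *operand = *it;
      ++it; // advance first: set() below erases the node we just visited
      updateInPlace(operand->owner, [&] { operand->set(&to); });
    }
  }
};

// Returns true when both uses moved and both owners were recorded.
bool demoReplaceAllUses() {
  ToyValue loopIndex{"loopIndex", {}};
  ToyValue blockIdx{"blockIdx", {}};
  ToyOperand first{"op_a", nullptr};
  ToyOperand second{"op_b", nullptr};
  first.set(&loopIndex);
  second.set(&loopIndex);

  ToyRewriter rewriter;
  rewriter.replaceAllUsesWith(loopIndex, blockIdx);

  return loopIndex.uses.empty() && blockIdx.uses.size() == 2 &&
         rewriter.modifiedOwners.size() == 2;
}
```

In this sketch, advancing the iterator before calling `set` plays the role of the early-increment range. The loops removed from GPUTransformOps.cpp avoided the same problem differently, by first copying the uses into a `SmallVector` and mutating afterwards.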