[llvm-branch-commits] [mlir] [mlir][Transforms] Support `replaceAllUsesWith` in dialect conversion (PR #84725)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Mar 11 00:33:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds support for `RewriterBase::replaceAllUsesWith` to the dialect conversion. Uses are not replaced immediately; instead, they are replaced in a delayed fashion during the "commit" phase. No type conversions are performed; this is consistent with the existing behavior of `ConversionPatternRewriter::replaceUsesOfBlockArgument`.
- `RewriterBase::replaceAllUsesWith` is now virtual, so that it can be overridden in the dialect conversion. Note: `RewriterBase::replaceOp` can now be turned into a non-virtual function in a follow-up commit.
- `ConversionPatternRewriter::replaceUsesOfBlockArgument` is generalized to `ConversionPatternRewriter::replaceAllUsesWith`, following the same implementation strategy.
- A new kind of "IR rewrite" is added: `ValueRewrite` with `ReplaceAllUsesRewrite` (replacing `ReplaceBlockArgRewrite`) as the only value rewrite for now.
- `replacedOps` is renamed to `erasedOps` to better capture its meaning.
---
Patch is 25.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84725.diff
8 Files Affected:
- (modified) mlir/include/mlir/IR/PatternMatch.h (+1-1)
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+5-3)
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
- (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+1-1)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+121-88)
- (modified) mlir/test/Transforms/test-legalizer.mlir (+18)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+1)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+24-3)
``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2be1e2e2b40276..3e11e00b9d4b40 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,7 +614,7 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db545..1797ee0876e437 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -697,9 +697,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
- /// Replace all the uses of the block argument `from` with value `to`.
- void replaceUsesOfBlockArgument(BlockArgument from, Value to);
-
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
@@ -720,6 +717,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
+ /// Find uses of `from` and replace them with `to`.
+ ///
+ /// Note: This function does not convert types.
+ void replaceAllUsesWith(Value from, Value to) override;
+
/// PatternRewriter hook for replacing an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 53b44aa3241bb1..d7ed9a196e8938 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -310,7 +310,7 @@ static void modifyFuncOpToUseBarePtrCallingConv(
Location loc = funcOp.getLoc();
auto placeholder = rewriter.create<LLVM::UndefOp>(
loc, typeConverter.convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+ rewriter.replaceAllUsesWith(arg, placeholder);
Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
memrefTy, arg);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..c6d2ddac9dbb19 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -201,7 +201,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
llvmFuncOp.getBody().getArgument(remapping->inputNo);
auto placeholder = rewriter.create<LLVM::UndefOp>(
loc, getTypeConverter()->convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+ rewriter.replaceAllUsesWith(newArg, placeholder);
Value desc = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), memrefTy, newArg);
rewriter.replaceOp(placeholder, {desc});
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..e4a022b7a0288b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numReplacedOps)
+ unsigned numErasedOps)
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
- numReplacedOps(numReplacedOps) {}
+ numErasedOps(numErasedOps) {}
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -163,8 +163,8 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
- /// The current number of replaced ops that are scheduled for erasure.
- unsigned numReplacedOps;
+ /// The current number of ops that are scheduled for erasure.
+ unsigned numErasedOps;
};
//===----------------------------------------------------------------------===//
@@ -190,13 +190,14 @@ class IRRewrite {
InlineBlock,
MoveBlock,
BlockTypeConversion,
- ReplaceBlockArg,
// Operation rewrites
MoveOperation,
ModifyOperation,
ReplaceOperation,
CreateOperation,
- UnresolvedMaterialization
+ UnresolvedMaterialization,
+ // Value rewrites
+ ReplaceAllUses
};
virtual ~IRRewrite() = default;
@@ -243,7 +244,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::ReplaceBlockArg;
+ rewrite->getKind() <= Kind::BlockTypeConversion;
}
protected:
@@ -487,27 +488,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
const TypeConverter *converter;
};
-/// Replacing a block argument. This rewrite is not immediately reflected in the
-/// IR. An internal IR mapping is updated, but the actual replacement is delayed
-/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
-public:
- ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
-
- static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::ReplaceBlockArg;
- }
-
- void commit(RewriterBase &rewriter) override;
-
- void rollback() override;
-
-private:
- BlockArgument arg;
-};
-
/// An operation rewrite.
class OperationRewrite : public IRRewrite {
public:
@@ -751,6 +731,44 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original output type. This is only used for argument conversions.
Type origOutputType;
};
+
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+ /// Return the value that this rewrite operates on.
+ Value getValue() const { return value; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() >= Kind::ReplaceAllUses &&
+ rewrite->getKind() <= Kind::ReplaceAllUses;
+ }
+
+protected:
+ ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Value value)
+ : IRRewrite(kind, rewriterImpl), value(value) {}
+
+ // The value that this rewrite operates on.
+ Value value;
+};
+
+/// Replacing a value. This rewrite is not immediately reflected in the IR. An
+/// internal IR mapping is updated, but the actual replacement is delayed until
+/// the rewrite is committed.
+class ReplaceAllUsesRewrite : public ValueRewrite {
+public:
+ ReplaceAllUsesRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Value value)
+ : ValueRewrite(Kind::ReplaceAllUses, rewriterImpl, value) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceAllUses;
+ }
+
+ void commit(RewriterBase &rewriter) override;
+
+ void rollback() override;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -832,8 +850,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converted.
bool isOpIgnored(Operation *op) const;
- /// Return "true" if the given operation was replaced or erased.
- bool wasOpReplaced(Operation *op) const;
+ /// Return "true" if the given operation is scheduled for erasure. (It may
+ /// still be visible in the IR, but should not be accessed.)
+ bool wasOpErased(Operation *op) const;
//===--------------------------------------------------------------------===//
// Type Conversion
@@ -982,11 +1001,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// tracked separately.
SetVector<Operation *> ignoredOps;
- /// A set of operations that were replaced/erased. Such ops are not erased
- /// immediately but only when the dialect conversion succeeds. In the mean
- /// time, they should no longer be considered for legalization and any attempt
- /// to modify/access them is invalid rewriter API usage.
- SetVector<Operation *> replacedOps;
+ /// A set of operations that were erased. Such ops are not erased immediately
+ /// but only when the dialect conversion succeeds. In the meantime, they
+ /// should no longer be considered for legalization and any attempt to
+ /// modify/access them is invalid rewriter API usage.
+ SetVector<Operation *> erasedOps;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1099,13 +1118,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
- if (!repl)
- return;
+void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.mapping.lookupOrNull(value);
+ assert(repl && "expected that value is mapped");
if (isa<BlockArgument>(repl)) {
- rewriter.replaceAllUsesWith(arg, repl);
+ rewriter.replaceAllUsesWith(value, repl);
return;
}
@@ -1114,13 +1132,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
// replacement value.
Operation *replOp = cast<OpResult>(repl).getOwner();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
@@ -1205,7 +1223,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
+ return RewriterState(rewrites.size(), ignoredOps.size(), erasedOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1216,8 +1234,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
- while (replacedOps.size() != state.numReplacedOps)
- replacedOps.pop_back();
+ while (erasedOps.size() != state.numErasedOps)
+ erasedOps.pop_back();
}
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1282,13 +1300,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
- // Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ // Check to see if this operation is ignored or was erased.
+ return erasedOps.count(op) || ignoredOps.count(op);
}
-bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
- // Check to see if this operation was replaced.
- return replacedOps.count(op);
+bool ConversionPatternRewriterImpl::wasOpErased(Operation *op) const {
+ // Check to see if this operation was scheduled for erasure.
+ return erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1434,7 +1452,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ appendRewrite<ReplaceAllUsesRewrite>(origArg);
continue;
}
@@ -1469,7 +1487,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
mapping.map(origArg, newArg);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ appendRewrite<ReplaceAllUsesRewrite>(origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
@@ -1535,8 +1553,8 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
- "attempting to insert into a block within a replaced/erased op");
+ assert(!wasOpErased(op->getParentOp()) &&
+ "attempting to insert into a block within an erased op");
if (!previous.isSet()) {
// This is a newly created op.
@@ -1571,8 +1589,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
resultChanged);
- // Mark this operation and all nested ops as replaced.
- op->walk([&](Operation *op) { replacedOps.insert(op); });
+ // Mark this operation and all nested ops as erased.
+ op->walk([&](Operation *op) { erasedOps.insert(op); });
}
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1583,8 +1601,8 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ assert(!wasOpErased(block->getParentOp()) &&
+ "attempting to insert into a region within an erased op");
LLVM_DEBUG(
{
Operation *parent = block->getParentOp();
@@ -1660,8 +1678,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- assert(!impl->wasOpReplaced(block->getParentOp()) &&
- "attempting to erase a block within a replaced/erased op");
+ assert(!impl->wasOpErased(block->getParentOp()) &&
+ "attempting to erase a block within an erased op");
// Mark all ops for erasure.
for (Operation &op : *block)
@@ -1678,41 +1696,59 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
- "attempting to apply a signature conversion to a block within a "
- "replaced/erased op");
+ assert(!impl->wasOpErased(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within an "
+ "erased op");
return impl->applySignatureConversion(*this, region, conversion, converter);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
- "attempting to apply a signature conversion to a block within a "
- "replaced/erased op");
+ assert(!impl->wasOpErased(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within an "
+ "erased op");
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
- "attempting to apply a signature conversion to a block within a "
- "replaced/erased op");
+ assert(!impl->wasOpErased(region->getParentOp()) &&
+ "attempting to apply a signature conversion to a block within an "
+ "erased op");
return impl->convertNonEntryRegionTypes(*this, region, converter,
blockConversions);
}
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
- Value to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, Value to) {
+#ifndef NDEBUG
LLVM_DEBUG({
- Operation *parentOp = from.getOwner()->getParentOp();
- impl->logger.startLine() << "** Replace Argument : '" << from
- << "'(in region of '" << parentOp->getName()
- << "'(" << from.getOwner()->getParentOp() << ")\n";
+ Block *parentBlock = from.getParentBlock();
+ Operation *parentOp = parentBlock ? parentBlock->getParentOp() : nullptr;
+ impl->logger.startLine() << "** Replace value : '" << from;
+ if (parentOp) {
+ impl->logger.getOStream() << "' (in region of '" << parentOp->getName()
+ << "'(" << parentOp << ")\n";
+ } else {
+ impl->logger.getOStream() << "' (detached)\n";
+ }
});
- impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
- impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+ if (OpResult opResult = dyn_cast<OpResult>(from)) {
+ assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+ "attempting to replace an OpResult defined by an erased op");
+ }
+ if (OpResult opResult = dyn_cast<OpResult>(to)) {
+ assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+ "attempting to replace with an OpResult defined by an erased op");
+ }
+ // A value cannot be replaced multiple times. That would likely require a more
+ // fine-grained tracking of replacements (i.e., each use must be tracked).
+ assert(!impl->mapping.lookupOrNull(from) && "value was already replaced");
+#endif // NDEBUG
+
+ impl->appendRewrite<ReplaceAllUsesRewrite>(from);
+ impl->mapping.map(from, to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -1738,10 +1774,10 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
#ifndef NDEBUG
assert(argValues.size() == source->getNumArguments() &&
"incorrect # of argument replacement values");
- assert(!impl->wasOpReplaced(source->getParentOp()) &&
- "attempting to inline a block from a replaced/erased op");
- assert(!impl->wasOpReplaced(dest->getParentOp()) &&
- "attempting to inline a block into a replaced/erased op");
+ assert(!impl->wasOpErased(source->getParentOp()) &&
+ "attempting to inline a block from an erased op");
+ assert(!impl->wasOpErased(dest->getParentOp()) &&
+ "attempting to inline a block into an erased op");
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
// The source block will be deleted, so it should not have any users (i.e.,
// there should be no predecessors).
@@ -1762,7 +1798,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
- replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
if (fastPath) {
// Move all ops at once.
@@ -1778,8 +1814,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
- assert(...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/84725
More information about the llvm-branch-commits
mailing list