[Mlir-commits] [mlir] 0816de1 - [mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument
River Riddle
llvmlistbot at llvm.org
Fri Apr 24 12:38:46 PDT 2020
Author: River Riddle
Date: 2020-04-24T12:37:32-07:00
New Revision: 0816de167a7418904287ffb8173e31516880364d
URL: https://github.com/llvm/llvm-project/commit/0816de167a7418904287ffb8173e31516880364d
DIFF: https://github.com/llvm/llvm-project/commit/0816de167a7418904287ffb8173e31516880364d.diff
LOG: [mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument
The current implementation of this method performs the replacement immediately, and thus doesn't support proper backtracking: the replacement cannot be undone if the conversion rolls back the pattern that requested it.
Differential Revision: https://reviews.llvm.org/D78790
Added:
Modified:
mlir/include/mlir/IR/Value.h
mlir/lib/IR/Value.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 42001ae7db87..95def7686792 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -145,6 +145,11 @@ class Value {
replaceAllUsesExcept(Value newValue,
const SmallPtrSetImpl<Operation *> &exceptions) const;
+  /// Replace each use of 'this' value with 'newValue' for which the given
+  /// callback returns true. The callback is invoked once for every use that
+  /// exists when the call starts and receives the operand under consideration;
+  /// uses for which it returns false keep referring to 'this' value.
+ void replaceUsesWithIf(Value newValue,
+ function_ref<bool(OpOperand &)> shouldReplace);
+
//===--------------------------------------------------------------------===//
// Uses
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index f011c1341769..fdc5ad6be887 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -125,6 +125,15 @@ void Value::replaceAllUsesExcept(
}
}
+/// Selectively rewire the uses of 'this' value: every operand for which
+/// `shouldReplace` answers true is redirected to `newValue`, all others are
+/// left as they are.
+void Value::replaceUsesWithIf(Value newValue,
+                              function_ref<bool(OpOperand &)> shouldReplace) {
+  // `set` moves an operand off this value's use list, so the range must be
+  // walked with an early-increment iterator to stay valid while mutating it.
+  for (OpOperand &operand : llvm::make_early_inc_range(getUses())) {
+    if (!shouldReplace(operand))
+      continue;
+    operand.set(newValue);
+  }
+}
+
//===--------------------------------------------------------------------===//
// Uses
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 04bf62753a9b..63db2c80c800 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -197,8 +197,6 @@ struct ArgConverter {
/// Fully replace uses of the old arguments with the new, materializing cast
/// operations as necessary.
- // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
- // implementation of replaceUsesOfBlockArgument is buggy.
void applyRewrites(ConversionValueMapping &mapping);
//===--------------------------------------------------------------------===//
@@ -436,9 +434,10 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations,
- unsigned numRootUpdates)
+ unsigned numArgReplacements, unsigned numBlockActions,
+ unsigned numIgnoredOperations, unsigned numRootUpdates)
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numArgReplacements(numArgReplacements),
numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
@@ -449,6 +448,9 @@ struct RewriterState {
/// The current number of replacements queued.
unsigned numReplacements;
+ /// The current number of argument replacements queued.
+ unsigned numArgReplacements;
+
/// The current number of block actions performed.
unsigned numBlockActions;
@@ -624,6 +626,9 @@ struct ConversionPatternRewriterImpl {
/// Ordered vector of any requested operation replacements.
SmallVector<OpReplacement, 4> replacements;
+ /// Ordered vector of any requested block argument replacements.
+ SmallVector<BlockArgument, 4> argReplacements;
+
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
@@ -654,8 +659,8 @@ struct ConversionPatternRewriterImpl {
+// Snapshot the current length of every rewrite log: created ops, replacements,
+// argument replacements, block actions, ignored ops and root updates.
+// `resetState` later truncates these logs back to the recorded sizes.
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), ignoredOps.size(),
- rootUpdates.size());
+ argReplacements.size(), blockActions.size(),
+ ignoredOps.size(), rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -664,6 +669,12 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
rootUpdates[i].resetOperation();
rootUpdates.resize(state.numRootUpdates);
+ // Reset any replaced arguments.
+ for (BlockArgument replacedArg :
+ llvm::drop_begin(argReplacements, state.numArgReplacements))
+ mapping.erase(replacedArg);
+ argReplacements.resize(state.numArgReplacements);
+
// Undo any block actions.
undoBlockActions(state.numBlockActions);
@@ -753,6 +764,25 @@ void ConversionPatternRewriterImpl::applyRewrites() {
argConverter.notifyOpRemoved(repl.op);
}
+ // Apply all of the requested argument replacements.
+ for (BlockArgument arg : argReplacements) {
+ Value repl = mapping.lookupOrDefault(arg);
+ if (repl.isa<BlockArgument>()) {
+ arg.replaceAllUsesWith(repl);
+ continue;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = repl.cast<OpResult>().getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+ }
+
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed.
@@ -907,11 +937,13 @@ Block *ConversionPatternRewriter::applySignatureConversion(
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
- for (auto &u : from.getUses()) {
- if (u.getOwner() == to.getDefiningOp())
- continue;
- u.getOwner()->replaceUsesOfWith(from, to);
- }
+  LLVM_DEBUG({
+    Operation *parentOp = from.getOwner()->getParentOp();
+    impl->logger.startLine() << "** Replace Argument : '" << from
+                             << "' (in region of '" << parentOp->getName()
+                             << "' (" << parentOp << "))\n";
+  });
+  // Only record the request here; the uses are rewritten in `applyRewrites`.
+  // Keeping it in `argReplacements` lets `resetState` drop it on rollback.
+  impl->argReplacements.push_back(from);
+  // NOTE(review): the key mapped below is the end of `from`'s current mapping
+  // chain, while `resetState` erases the key `from` itself. These differ when
+  // `from` was already mapped; confirm undo is correct in that case.
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 557908d2b1a4..5c5434446abe 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -197,3 +197,17 @@ func @create_illegal_block() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: @undo_block_arg_replace
+// The block-argument replacement requested by the test pattern must be undone,
+// so the return is expected to still use the original block argument.
+func @undo_block_arg_replace() {
+ "test.undo_block_arg_replace"() ({
+ ^bb0(%arg0: i32):
+ // CHECK: ^bb0(%[[ARG:.*]]: i32):
+ // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+
+ "test.return"(%arg0) : (i32) -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c14dea2b9534..d21d59ca1e8b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -238,6 +238,24 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
+/// Conversion pattern that exercises rollback of block-argument replacement:
+/// it redirects the uses of the first argument of the op's entry block to a
+/// freshly created `ILLegalOpF` and reports an in-place update of the op.
+struct TestUndoBlockArgReplace : public ConversionPattern {
+  TestUndoBlockArgReplace(MLIRContext *ctx)
+      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block &entryBlock = op->getRegion(0).front();
+    Value replacement =
+        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(0), replacement);
+    // `op` itself is not changed; the empty callback only reports to the
+    // rewriter that the root was updated in place.
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
@@ -449,12 +467,14 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
- patterns.insert<
- TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
- TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
- TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
+ patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+ TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestPassthroughInvalidOp,
+ TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+ TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite>(
+ &getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
More information about the Mlir-commits
mailing list