[Mlir-commits] [mlir] 0816de1 - [mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument

River Riddle llvmlistbot at llvm.org
Fri Apr 24 12:38:46 PDT 2020


Author: River Riddle
Date: 2020-04-24T12:37:32-07:00
New Revision: 0816de167a7418904287ffb8173e31516880364d

URL: https://github.com/llvm/llvm-project/commit/0816de167a7418904287ffb8173e31516880364d
DIFF: https://github.com/llvm/llvm-project/commit/0816de167a7418904287ffb8173e31516880364d.diff

LOG: [mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument

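The current implementation of ConversionPatternRewriter::replaceUsesOfBlockArgument performs the replacement directly, and thus doesn't support proper backtracking: the replacement cannot be undone if the pattern that requested it is rolled back. This change instead records the replaced arguments, discards them when the rewriter state is reset, and only performs the replacements when the rewrites are applied.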

Differential Revision: https://reviews.llvm.org/D78790

Added: 
    

Modified: 
    mlir/include/mlir/IR/Value.h
    mlir/lib/IR/Value.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 42001ae7db87..95def7686792 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -145,6 +145,11 @@ class Value {
   replaceAllUsesExcept(Value newValue,
                        const SmallPtrSetImpl<Operation *> &exceptions) const;
 
+  /// Replace all uses of 'this' value with 'newValue' if the given callback
+  /// returns true.
+  void replaceUsesWithIf(Value newValue,
+                         function_ref<bool(OpOperand &)> shouldReplace);
+
   //===--------------------------------------------------------------------===//
   // Uses
 

diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index f011c1341769..fdc5ad6be887 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -125,6 +125,15 @@ void Value::replaceAllUsesExcept(
   }
 }
 
+/// Replace all uses of 'this' value with 'newValue' if the given callback
+/// returns true.
+void Value::replaceUsesWithIf(Value newValue,
+                              function_ref<bool(OpOperand &)> shouldReplace) {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses()))
+    if (shouldReplace(use))
+      use.set(newValue);
+}
+
 //===--------------------------------------------------------------------===//
 // Uses
 

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 04bf62753a9b..63db2c80c800 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -197,8 +197,6 @@ struct ArgConverter {
 
   /// Fully replace uses of the old arguments with the new, materializing cast
   /// operations as necessary.
-  // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
-  // implementation of replaceUsesOfBlockArgument is buggy.
   void applyRewrites(ConversionValueMapping &mapping);
 
   //===--------------------------------------------------------------------===//
@@ -436,9 +434,10 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numReplacements,
-                unsigned numBlockActions, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numArgReplacements, unsigned numBlockActions,
+                unsigned numIgnoredOperations, unsigned numRootUpdates)
       : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+        numArgReplacements(numArgReplacements),
         numBlockActions(numBlockActions),
         numIgnoredOperations(numIgnoredOperations),
         numRootUpdates(numRootUpdates) {}
@@ -449,6 +448,9 @@ struct RewriterState {
   /// The current number of replacements queued.
   unsigned numReplacements;
 
+  /// The current number of argument replacements queued.
+  unsigned numArgReplacements;
+
   /// The current number of block actions performed.
   unsigned numBlockActions;
 
@@ -624,6 +626,9 @@ struct ConversionPatternRewriterImpl {
   /// Ordered vector of any requested operation replacements.
   SmallVector<OpReplacement, 4> replacements;
 
+  /// Ordered vector of any requested block argument replacements.
+  SmallVector<BlockArgument, 4> argReplacements;
+
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<BlockAction, 4> blockActions;
 
@@ -654,8 +659,8 @@ struct ConversionPatternRewriterImpl {
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), replacements.size(),
-                       blockActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       argReplacements.size(), blockActions.size(),
+                       ignoredOps.size(), rootUpdates.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -664,6 +669,12 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     rootUpdates[i].resetOperation();
   rootUpdates.resize(state.numRootUpdates);
 
+  // Reset any replaced arguments.
+  for (BlockArgument replacedArg :
+       llvm::drop_begin(argReplacements, state.numArgReplacements))
+    mapping.erase(replacedArg);
+  argReplacements.resize(state.numArgReplacements);
+
   // Undo any block actions.
   undoBlockActions(state.numBlockActions);
 
@@ -753,6 +764,25 @@ void ConversionPatternRewriterImpl::applyRewrites() {
       argConverter.notifyOpRemoved(repl.op);
   }
 
+  // Apply all of the requested argument replacements.
+  for (BlockArgument arg : argReplacements) {
+    Value repl = mapping.lookupOrDefault(arg);
+    if (repl.isa<BlockArgument>()) {
+      arg.replaceAllUsesWith(repl);
+      continue;
+    }
+
+    // If the replacement value is an operation, we check to make sure that we
+    // don't replace uses that are within the parent operation of the
+    // replacement value.
+    Operation *replOp = repl.cast<OpResult>().getOwner();
+    Block *replBlock = replOp->getBlock();
+    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+      Operation *user = operand.getOwner();
+      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+    });
+  }
+
   // In a second pass, erase all of the replaced operations in reverse. This
   // allows processing nested operations before their parent region is
   // destroyed.
@@ -907,11 +937,13 @@ Block *ConversionPatternRewriter::applySignatureConversion(
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
-  for (auto &u : from.getUses()) {
-    if (u.getOwner() == to.getDefiningOp())
-      continue;
-    u.getOwner()->replaceUsesOfWith(from, to);
-  }
+  LLVM_DEBUG({
+    Operation *parentOp = from.getOwner()->getParentOp();
+    impl->logger.startLine() << "** Replace Argument : '" << from
+                             << "'(in region of '" << parentOp->getName()
+                             << "'(" << from.getOwner()->getParentOp() << ")\n";
+  });
+  impl->argReplacements.push_back(from);
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 557908d2b1a4..5c5434446abe 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -197,3 +197,17 @@ func @create_illegal_block() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @undo_block_arg_replace
+func @undo_block_arg_replace() {
+  "test.undo_block_arg_replace"() ({
+  ^bb0(%arg0: i32):
+    // CHECK: ^bb0(%[[ARG:.*]]: i32):
+    // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+
+    "test.return"(%arg0) : (i32) -> ()
+  }) : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c14dea2b9534..d21d59ca1e8b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -238,6 +238,24 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
+/// A simple pattern that tests the undo mechanism when replacing the uses of a
+/// block argument.
+struct TestUndoBlockArgReplace : public ConversionPattern {
+  TestUndoBlockArgReplace(MLIRContext *ctx)
+      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto illegalOp =
+        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
+                                        illegalOp);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
 
@@ -449,12 +467,14 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    patterns.insert<
-        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
-        TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
-        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-        TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
+    patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+                    TestCreateBlock, TestCreateIllegalBlock,
+                    TestUndoBlockArgReplace, TestPassthroughInvalidOp,
+                    TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+                    TestChangeProducerTypeF32ToF64,
+                    TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+                    TestNonRootReplacement, TestBoundedRecursiveRewrite>(
+        &getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);


        


More information about the Mlir-commits mailing list