[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Fix `replaceUsesOfBlockArgument` (PR #117666)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Nov 25 20:48:55 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/117666

This commit fixes the implementation of `ConversionPatternRewriter::replaceUsesOfBlockArgument`. The old implementation did not match its documentation.

```
/// Replace all the uses of the block argument `from` with value `to`.
void ConversionPatternRewriter::replaceUsesOfBlockArgument(
    BlockArgument from, Value to) {
  // ...
  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
```

The extra `mapping.lookupOrDefault` call was incorrect: if `from` was already mapped to some value, the old implementation replaced that mapped value instead of `from` itself.

This function is typically used after a block signature conversion to "fix up" some block arguments. During a 1:N conversion, an argument materialization is inserted. The old implementation could be used to replace the argument materialization by passing the old block argument as the `from` parameter. This was unintuitive, because it is not the block argument itself that is being replaced. Furthermore, replacing the block arguments of an erased block (to be precise, a block that is scheduled for erasure) is incorrect from an API perspective, because a block argument of an erased block should not have any uses anymore.

The new implementation of `replaceUsesOfBlockArgument` now does what the documentation says: it replaces the uses of the `from` argument, without any extra mapping lookup. To replace an argument materialization, users should instead call `replaceOp` on the argument materialization.


>From 18302a346179fda0b04416883a380decae3e4bfd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 26 Nov 2024 05:38:50 +0100
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Fix
 `replaceUsesOfBlockArgument`

---
 .../Transforms/Utils/DialectConversion.cpp    |  2 +-
 mlir/test/Transforms/test-legalizer.mlir      | 20 +++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 45 ++++++++++++++++++-
 3 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 60b3656d98a38e..8b7ffd791d2591 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1641,7 +1641,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  impl->mapping.map(from, to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 624add08846a28..dfa619796700eb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,3 +472,23 @@ func.func @circular_mapping() {
   %0 = "test.erase_op"() : () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_replace_uses_of_block_arg() {
+//       CHECK:   "test.convert_block_and_replace_arg"() ({
+//       CHECK:   ^bb0(%[[arg0:.*]]: f64, %[[arg1:.*]]: f64):
+//       CHECK:     %[[producer:.*]] = "test.type_producer"() : () -> f64
+//       CHECK:     %[[cast:.*]] = "test.cast"(%[[producer]], %[[arg1]]) : (f64, f64) -> f32
+//       CHECK:     "test.some_user"(%[[cast]]) : (f32) -> ()
+//       CHECK:   }) {legal} : () -> ()
+//       CHECK:   "test.return"() : () -> ()
+//       CHECK: }
+func.func @test_replace_uses_of_block_arg() {
+  "test.convert_block_and_replace_arg"() ({
+  ^bb0(%arg0: f32):
+    // expected-remark @below{{'test.some_user' is not legalizable}}
+    "test.some_user"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e931b394c86210..54699f402e2f1e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -902,6 +902,44 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+struct TestConvertBlockAndReplaceArg : public ConversionPattern {
+  TestConvertBlockAndReplaceArg(MLIRContext *ctx,
+                                const TypeConverter &converter)
+      : ConversionPattern(converter, "test.convert_block_and_replace_arg",
+                          /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Match only: one region, holding one block, with one argument.
+    if (op->getNumRegions() != 1)
+      return failure();
+    if (op->getRegion(0).getBlocks().size() != 1)
+      return failure();
+    Block *block = &op->getRegion(0).front();
+    if (block->getArguments().size() != 1)
+      return failure();
+
+    // 1:2 signature conversion: the block argument becomes two F64 args.
+    TypeConverter::SignatureConversion result(1);
+    result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()});
+    Block *newBlock =
+        rewriter.applySignatureConversion(block, result, getTypeConverter());
+
+    // Replace uses of the first *new* block argument with a new op result.
+    BlockArgument arg = newBlock->getArgument(0);
+    rewriter.setInsertionPointToStart(newBlock);
+    Value zero = rewriter.create<TestTypeProducerOp>(op->getLoc(),
+                                                     rewriter.getF64Type());
+    rewriter.replaceUsesOfBlockArgument(arg, zero);
+
+    // Set the "legal" unit attr that the dynamic legality check looks for.
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
 /// This is to test the rollback logic.
 struct TestUndoMoveOpBefore : public ConversionPattern {
@@ -1265,7 +1303,8 @@ struct TestLegalizePatternDriver
              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
              TestUndoPropertiesModification, TestEraseOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp>(&getContext(), converter);
+                 TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>(
+        &getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1356,10 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
 
+    target.addDynamicallyLegalOp(
+        OperationName("test.convert_block_and_replace_arg", &getContext()),
+        [](Operation *op) { return op->hasAttr("legal"); });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;



More information about the llvm-branch-commits mailing list