[Mlir-commits] [mlir] 0409eb2 - [mlir] Keep track of region signature conversions as argument replacements

Alex Zinenko llvmlistbot at llvm.org
Tue Feb 2 01:38:40 PST 2021


Author: Alex Zinenko
Date: 2021-02-02T10:38:31+01:00
New Revision: 0409eb287414b71cfdcda13796c93794b71ea6d4

URL: https://github.com/llvm/llvm-project/commit/0409eb287414b71cfdcda13796c93794b71ea6d4
DIFF: https://github.com/llvm/llvm-project/commit/0409eb287414b71cfdcda13796c93794b71ea6d4.diff

LOG: [mlir] Keep track of region signature conversions as argument replacements

In dialect conversion, signature conversions essentially perform block argument
replacement and are added to the general value remapping. However, the replaced
values were not tracked, so if a signature conversion was rolled back, the
construction of operand lists for subsequent patterns could have obtained
block arguments from the mapping and passed them to the pattern, leading to a
use-after-free. Keep track of signature conversions in the same way as normal
block argument replacements, and erase such replacements from the general
mapping when the conversion is rolled back.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95688

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ae62a63b1228..ecbd57d947e8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -256,8 +256,10 @@ struct ArgConverter {
   /// Attempt to convert the signature of the given block, if successful a new
   /// block is returned containing the new arguments. Returns `block` if it did
   /// not require conversion.
-  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
-                                      ConversionValueMapping &mapping);
+  FailureOr<Block *>
+  convertSignature(Block *block, TypeConverter &converter,
+                   ConversionValueMapping &mapping,
+                   SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Apply the given signature conversion on the given block. The new block
   /// containing the updated signature is returned. If no conversions were
@@ -268,7 +270,8 @@ struct ArgConverter {
   Block *applySignatureConversion(
       Block *block, TypeConverter &converter,
       TypeConverter::SignatureConversion &signatureConversion,
-      ConversionValueMapping &mapping);
+      ConversionValueMapping &mapping,
+      SmallVectorImpl<BlockArgument> &argReplacements);
 
   /// Insert a new conversion into the cache.
   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
@@ -425,9 +428,9 @@ LogicalResult ArgConverter::materializeLiveConversions(
 //===----------------------------------------------------------------------===//
 // Conversion
 
-FailureOr<Block *>
-ArgConverter::convertSignature(Block *block, TypeConverter &converter,
-                               ConversionValueMapping &mapping) {
+FailureOr<Block *> ArgConverter::convertSignature(
+    Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // Check if the block was already converted. If the block is detached,
   // conservatively assume it is going to be deleted.
   if (hasBeenConverted(block) || !block->getParent())
@@ -435,14 +438,16 @@ ArgConverter::convertSignature(Block *block, TypeConverter &converter,
 
   // Try to convert the signature for the block with the provided converter.
   if (auto conversion = converter.convertBlockSignature(block))
-    return applySignatureConversion(block, converter, *conversion, mapping);
+    return applySignatureConversion(block, converter, *conversion, mapping,
+                                    argReplacements);
   return failure();
 }
 
 Block *ArgConverter::applySignatureConversion(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion &signatureConversion,
-    ConversionValueMapping &mapping) {
+    ConversionValueMapping &mapping,
+    SmallVectorImpl<BlockArgument> &argReplacements) {
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -477,6 +482,7 @@ Block *ArgConverter::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
+      argReplacements.push_back(origArg);
       continue;
     }
 
@@ -492,6 +498,7 @@ Block *ArgConverter::applySignatureConversion(
       newArg = replArgs.front();
     }
     mapping.map(origArg, newArg);
+    argReplacements.push_back(origArg);
     info.argInfo[i] =
         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1113,9 +1120,10 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion *conversion) {
   FailureOr<Block *> result =
-      conversion ? argConverter.applySignatureConversion(block, converter,
-                                                         *conversion, mapping)
-                 : argConverter.convertSignature(block, converter, mapping);
+      conversion ? argConverter.applySignatureConversion(
+                       block, converter, *conversion, mapping, argReplacements)
+                 : argConverter.convertSignature(block, converter, mapping,
+                                                 argReplacements);
   if (Block *newBlock = result.getValue()) {
     if (newBlock != block)
       blockActions.push_back(BlockAction::getTypeConversion(newBlock));

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c56b3c8ca1e2..1ea8ddc2660c 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -62,3 +62,18 @@ func @test_valid_result_legalization() {
   %result = "test.type_producer"() : () -> f32
   "foo.return"(%result) : (f32) -> ()
 }
+
+// -----
+
+// Should not segfault here but gracefully fail.
+// CHECK-LABEL: func @test_signature_conversion_undo
+func @test_signature_conversion_undo() {
+  // CHECK: test.signature_conversion_undo
+  "test.signature_conversion_undo"() ({
+  // CHECK: ^{{.*}}(%{{.*}}: f32):
+  ^bb0(%arg0: f32):
+    "test.type_consumer"(%arg0) : (f32) -> ()
+    "test.return"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 89d2ee87356b..ec836f2385b6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1289,6 +1289,10 @@ def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
   let results = (outs Variadic<AnyType>:$result);
 }
 
+def TestSignatureConversionUndoOp : TEST_Op<"signature_conversion_undo"> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a4c16a6d533f..7da02ed0c5c3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -774,6 +774,34 @@ struct TestTypeConversionProducer
   }
 };
 
+/// Call signature conversion and then fail the rewrite to trigger the undo
+/// mechanism.
+struct TestSignatureConversionUndo
+    : public OpConversionPattern<TestSignatureConversionUndoOp> {
+  using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
+    return failure();
+  }
+};
+
+/// Just forward the operands to the root op. This is essentially a no-op
+/// pattern that is used to trigger target materialization.
+struct TestTypeConsumerForward
+    : public OpConversionPattern<TestTypeConsumerOp> {
+  using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
+    return success();
+  }
+};
+
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -836,7 +864,8 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     OwningRewritePatternList patterns;
-    patterns.insert<TestTypeConversionProducer>(converter, &getContext());
+    patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
+                    TestSignatureConversionUndo>(converter, &getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
 


        


More information about the Mlir-commits mailing list