[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: fix crash when converting detached region (PR #100633)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 25 12:01:09 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit fixes a crash in the dialect conversion when applying a signature conversion to a block inside a detached region.

This fixes an issue reported in https://github.com/llvm/llvm-project/pull/97213/files/4114d5be87596e11d86706a338248ebf05cf7150#r1691809730.

---
Full diff: https://github.com/llvm/llvm-project/pull/100633.diff


3 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+2-1) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+15) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+41-10) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a045868b66031..059288e18049b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1370,7 +1370,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
+  OpBuilder builder(outputType.getContext());
+  builder.setInsertionPoint(insertBlock, insertPt);
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 65c947198e06e..b153f8959b071 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -437,3 +437,18 @@ func.func @fold_legalization() -> i32 {
   %1 = "test.op_in_place_self_fold"() : () -> (i32)
   "test.return"(%1) : (i32) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @convert_detached_signature()
+//       CHECK:   "test.legal_op_with_region"() ({
+//       CHECK:   ^bb0(%arg0: f64):
+//       CHECK:     "test.return"() : () -> ()
+//       CHECK:   }) : () -> ()
+func.func @convert_detached_signature() {
+  "test.detached_signature_conversion"() ({
+  ^bb0(%arg0: i64):
+    "test.return"() : () -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a14a5da341098..83672843b16ce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -685,6 +685,35 @@ namespace {
 //===----------------------------------------------------------------------===//
 // Region-Block Rewrite Testing
 
+/// This pattern applies a signature conversion to a block inside a detached
+/// region.
+struct TestDetachedSignatureConversion : public ConversionPattern {
+  TestDetachedSignatureConversion(MLIRContext *ctx)
+      : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
+                          ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op->getNumRegions() != 1)
+      return failure();
+    OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+                         op->getResultTypes(), {}, BlockRange());
+    Region *newRegion = state.addRegion();
+    rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
+                                newRegion->begin());
+    TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+    for (unsigned i = 0; i < newRegion->getNumArguments(); ++i) {
+      result.addInputs(i, rewriter.getF64Type());
+    }
+    rewriter.applySignatureConversion(&newRegion->front(), result);
+    Operation *newOp = rewriter.create(state);
+    newOp->dump();
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
 /// This pattern is a simple pattern that inlines the first region of a given
 /// operation into the parent region.
 struct TestRegionRewriteBlockMovement : public ConversionPattern {
@@ -1112,16 +1141,16 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns
-        .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-             TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
-             TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
-             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-             TestNonRootReplacement, TestBoundedRecursiveRewrite,
-             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-             TestUndoPropertiesModification>(&getContext());
+    patterns.add<
+        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
+        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
+        TestUpdateConsumerType, TestNonRootReplacement,
+        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
+        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+        TestUndoPropertiesModification>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1132,6 +1161,8 @@ struct TestLegalizePatternDriver
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                       TerminatorOp, OneRegionOp>();
+    target.addLegalOp(
+        OperationName("test.legal_op_with_region", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/100633


More information about the Mlir-commits mailing list