[Mlir-commits] [mlir] 8fc3294 - [mlir][Transforms] Dialect conversion: Add missing "else if" branch (#101148)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 07:36:50 PDT 2024


Author: Matthias Springer
Date: 2024-07-30T16:36:47+02:00
New Revision: 8fc329421b2e3bc5cdda98ce303ea3b07af58ebc

URL: https://github.com/llvm/llvm-project/commit/8fc329421b2e3bc5cdda98ce303ea3b07af58ebc
DIFF: https://github.com/llvm/llvm-project/commit/8fc329421b2e3bc5cdda98ce303ea3b07af58ebc.diff

LOG: [mlir][Transforms] Dialect conversion: Add missing "else if" branch (#101148)

This `else if` branch was accidentally dropped in #97213, and no test
covered it. Restore it and add an MLIR test.

When a pattern is run without a type converter, we can assume that the
new block argument types of a signature conversion are legal. That's
because they were specified by the user. This won't work for 1->N
conversions due to limitations in the dialect conversion infrastructure,
so the original `FIXME` has to stay in place.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 059288e18049b..f26aa0a1516a6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1328,15 +1328,19 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     mapping.map(origArg, argMat);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // FIXME: We simply pass through the replacement argument if there wasn't a
-    // converter, which isn't great as it allows implicit type conversions to
-    // appear. We should properly restructure this code to handle cases where a
-    // converter isn't provided and also to properly handle the case where an
-    // argument materialization is actually a temporary source materialization
-    // (e.g. in the case of 1->N).
     Type legalOutputType;
-    if (converter)
+    if (converter) {
       legalOutputType = converter->convertType(origArgType);
+    } else if (replArgs.size() == 1) {
+      // When there is no type converter, assume that the new block argument
+      // types are legal. This is reasonable to assume because they were
+      // specified by the user.
+      // FIXME: This won't work for 1->N conversions because multiple output
+      // types are not supported in parts of the dialect conversion. In such a
+      // case, we currently use the original block argument type (produced by
+      // the argument materialization).
+      legalOutputType = replArgs[0].getType();
+    }
     if (legalOutputType && legalOutputType != origArgType) {
       Value targetMat = buildUnresolvedTargetMaterialization(
           origArg.getLoc(), argMat, legalOutputType, converter);

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 8254be68912c8..d0563fed8e5d9 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -127,3 +127,18 @@ llvm.func @unsupported_func_op_interface() {
   // CHECK: llvm.return
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: func @test_signature_conversion_no_converter()
+func.func @test_signature_conversion_no_converter() {
+  // CHECK: "test.signature_conversion_no_converter"() ({
+  // CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
+  "test.signature_conversion_no_converter"() ({
+  ^bb0(%arg0: f32):
+    // CHECK: "test.legal_op_d"(%[[arg0]]) : (f64) -> ()
+    "test.replace_with_legal_op"(%arg0) : (f32) -> ()
+    "test.return"() : () -> ()
+  }) : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2d97a02b8076a..2b55bff3538d3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1884,6 +1884,7 @@ def LegalOpA : TEST_Op<"legal_op_a">,
 def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
 def LegalOpC : TEST_Op<"legal_op_c">,
   Arguments<(ins I32)>, Results<(outs I32)>;
+def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
 
 // Check that the conversion infrastructure can properly undo the creation of
 // operations where an operation was created before its parent, in this case,

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 0546523a58c80..91dfb2faa80a1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1580,6 +1580,17 @@ struct TestTypeConversionAnotherProducer
   }
 };
 
+struct TestReplaceWithLegalOp : public ConversionPattern {
+  TestReplaceWithLegalOp(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+    return success();
+  }
+};
+
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
@@ -1671,6 +1682,7 @@ struct TestTypeConversionDriver
 
     // Initialize the conversion target.
     mlir::ConversionTarget target(getContext());
+    target.addLegalOp<LegalOpD>();
     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
       return op.getType().isF64() || op.getType().isInteger(64) ||
@@ -1696,7 +1708,8 @@ struct TestTypeConversionDriver
                  TestSignatureConversionUndo,
                  TestTestSignatureConversionNoConverter>(converter,
                                                          &getContext());
-    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
+    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
+        &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 


        


More information about the Mlir-commits mailing list