[Mlir-commits] [mlir] 08e6566 - [mlir][Func] Support 1:N result type conversions in `func.call` conversion (#117413)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 23 03:13:46 PST 2024


Author: Matthias Springer
Date: 2024-11-23T20:13:43+09:00
New Revision: 08e6566d7a310ace0660cbf3fbeb3f1c0c283295

URL: https://github.com/llvm/llvm-project/commit/08e6566d7a310ace0660cbf3fbeb3f1c0c283295
DIFF: https://github.com/llvm/llvm-project/commit/08e6566d7a310ace0660cbf3fbeb3f1c0c283295.diff

LOG: [mlir][Func] Support 1:N result type conversions in `func.call` conversion (#117413)

This commit adds support for 1:N result type conversions for `func.call`
ops. In that case, the original op is replaced via `replaceOpWithMultiple`,
which inserts argument materializations that convert the new results back to
the original result types.

This commit is in preparation for merging the 1:1 and 1:N conversion
drivers.

Added: 
    

Modified: 
    mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..b1cde6ca5d2fca 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -23,21 +23,34 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
   LogicalResult
   matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Convert the original function results.
+    // Convert the original function results. Keep track of how many result
+    // types an original result type is converted into.
+    SmallVector<size_t> numResultsReplacments;
     SmallVector<Type, 1> convertedResults;
-    if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
-                                           convertedResults)))
-      return failure();
-
-    // If this isn't a one-to-one type mapping, we don't know how to aggregate
-    // the results.
-    if (callOp->getNumResults() != convertedResults.size())
-      return failure();
+    size_t numFlattenedResults = 0;
+    for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
+      if (failed(typeConverter->convertTypes(type, convertedResults)))
+        return failure();
+      numResultsReplacments.push_back(convertedResults.size() -
+                                      numFlattenedResults);
+      numFlattenedResults = convertedResults.size();
+    }
 
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
-    rewriter.replaceOpWithNewOp<CallOp>(
-        callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+    auto newCallOp =
+        rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
+                                convertedResults, adaptor.getOperands());
+    SmallVector<ValueRange> replacements;
+    size_t offset = 0;
+    for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
+      replacements.push_back(
+          newCallOp->getResults().slice(offset, numResultsReplacments[i]));
+      offset += numResultsReplacments[i];
+    }
+    assert(offset == convertedResults.size() &&
+           "expected that all converted results are used");
+    rewriter.replaceOpWithMultiple(callOp, replacements);
     return success();
   }
 };

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..e05f444afa68f0 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -379,15 +379,24 @@ builtin.module {
 
 // -----
 
-// expected-remark @below {{applyPartialConversion failed}}
 module {
-  func.func private @callee(%0 : f32) -> f32
-
-  func.func @caller( %arg: f32) {
-    // expected-error @below {{failed to legalize}}
-    %1 = func.call @callee(%arg) : (f32) -> f32
-    return
-  }
+// CHECK-LABEL: func.func private @callee() -> (f16, f16)
+func.func private @callee() -> (f32, i24)
+
+// CHECK: func.func @caller()
+func.func @caller() {
+  // f32 is converted to (f16, f16).
+  // i24 is converted to ().
+  // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
+  %0:2 = func.call @callee() : () -> (f32, i24)
+
+  // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
+  // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+  // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
+  // expected-remark @below{{'test.some_user' is not legalizable}}
+  "test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
+  "test.return"() : () -> ()
+}
 }
 
 // -----

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..bbd55938718fe7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1215,6 +1215,11 @@ struct TestTypeConverter : public TypeConverter {
       return success();
     }
 
+    // Drop I24 types.
+    if (t.isInteger(24)) {
+      return success();
+    }
+
     // Otherwise, convert the type directly.
     results.push_back(t);
     return success();


        


More information about the Mlir-commits mailing list