[flang-commits] [flang] c336e72 - [flang] Fix function result rewrite for CPTR type
via flang-commits
flang-commits at lists.llvm.org
Tue Nov 8 17:25:54 PST 2022
Author: Peixin-Qiao
Date: 2022-11-09T09:24:38+08:00
New Revision: c336e72c82b11732f1728ab2f608ed99c7843258
URL: https://github.com/llvm/llvm-project/commit/c336e72c82b11732f1728ab2f608ed99c7843258
DIFF: https://github.com/llvm/llvm-project/commit/c336e72c82b11732f1728ab2f608ed99c7843258.diff
LOG: [flang] Fix function result rewrite for CPTR type
Not all derived types can be treated as abstract results. The CPTR type
should be returned by value so that it is interoperable with C
functions. This fixes the function result rewrite for the CPTR type, but
the fix should be generalized to all derived types. The ABI for C
interoperability of derived types is architecture dependent, and
support for it should be added later.
Reviewed By: PeteSteinfeld, jeanPerier
Differential Revision: https://reviews.llvm.org/D137548
Added:
Modified:
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/lib/Optimizer/Transforms/AbstractResult.cpp
flang/test/Fir/abstract-results.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index c509ce0fcdcfb..67b4d1af7cf17 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -948,10 +948,6 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) {
if (ty.getNumResults() == 0)
return false;
auto resultType = ty.getResult(0);
- // FIXME: The interoperable derived type needs more investigations and tests.
- // The derived type without BIND attribute may also not be abstract result.
- if (fir::isa_builtin_cptr_type(resultType))
- return false;
return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
}
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 87f8a4b399375..6cdc445ac18ff 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
@@ -56,6 +58,18 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
/*resultTypes=*/{});
}
+/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
+/// Follow the ABI for interoperability with C.
+static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
+ auto resultType = funcTy.getResult(0);
+ assert(fir::isa_builtin_cptr_type(resultType));
+ llvm::SmallVector<mlir::Type> outputTypes;
+ auto recTy = resultType.dyn_cast<fir::RecordType>();
+ outputTypes.emplace_back(recTy.getTypeList()[0].second);
+ return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
+ outputTypes);
+}
+
static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
return resultType.isa<fir::SequenceType, fir::RecordType>() &&
shouldBoxResult;
@@ -92,28 +106,50 @@ class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
saveResult.getTypeparams());
llvm::SmallVector<mlir::Type> newResultTypes;
+ // TODO: This should be generalized for derived types, and it is
+ // architecture and OS dependent.
+ bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
+ fir::CallOp newCallOp;
+ if (isResultBuiltinCPtr) {
+ auto recTy = result.getType().dyn_cast<fir::RecordType>();
+ newResultTypes.emplace_back(recTy.getTypeList()[0].second);
+ }
if (callOp.getCallee()) {
- llvm::SmallVector<mlir::Value> newOperands = {arg};
+ llvm::SmallVector<mlir::Value> newOperands;
+ if (!isResultBuiltinCPtr)
+ newOperands.emplace_back(arg);
newOperands.append(callOp.getOperands().begin(),
callOp.getOperands().end());
- rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
- newOperands);
+ newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
+ newResultTypes, newOperands);
} else {
// Indirect calls.
- llvm::SmallVector<mlir::Type> newInputTypes = {argType};
+ llvm::SmallVector<mlir::Type> newInputTypes;
+ if (!isResultBuiltinCPtr)
+ newInputTypes.emplace_back(argType);
for (auto operand : callOp.getOperands().drop_front())
newInputTypes.push_back(operand.getType());
- auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
- newResultTypes);
+ auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
+ newInputTypes, newResultTypes);
llvm::SmallVector<mlir::Value> newOperands;
- newOperands.push_back(
- rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
- newOperands.push_back(arg);
+ newOperands.push_back(rewriter.create<fir::ConvertOp>(
+ loc, newFuncTy, callOp.getOperand(0)));
+ if (!isResultBuiltinCPtr)
+ newOperands.push_back(arg);
newOperands.append(callOp.getOperands().begin() + 1,
callOp.getOperands().end());
- rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
- newOperands);
+ newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+ newResultTypes, newOperands);
+ }
+ if (isResultBuiltinCPtr) {
+ mlir::Value save = saveResult.getMemref();
+ auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ fir::KindMapping kindMap = fir::getKindMapping(module);
+ FirOpBuilder builder(rewriter, kindMap);
+ mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
+ builder, loc, save, result.getType());
+ rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
}
callOp->dropAllReferences();
rewriter.eraseOp(callOp);
@@ -146,12 +182,28 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,
mlir::PatternRewriter &rewriter) const override {
+ auto loc = ret.getLoc();
rewriter.setInsertionPoint(ret);
auto returnedValue = ret.getOperand(0);
bool replacedStorage = false;
if (auto *op = returnedValue.getDefiningOp())
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
auto resultStorage = load.getMemref();
+ // TODO: This should be generalized for derived types, and it is
+ // architecture and OS dependent.
+ if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
+ rewriter.eraseOp(load);
+ auto module = ret->getParentOfType<mlir::ModuleOp>();
+ fir::KindMapping kindMap = fir::getKindMapping(module);
+ FirOpBuilder builder(rewriter, kindMap);
+ mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
+ builder, loc, resultStorage, returnedValue.getType());
+ mlir::Value retValue = rewriter.create<fir::LoadOp>(
+ loc, fir::unwrapRefType(retAddr.getType()), retAddr);
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
+ ret, mlir::ValueRange{retValue});
+ return mlir::success();
+ }
load.getMemref().replaceAllUsesWith(newArg);
replacedStorage = true;
if (auto *alloc = resultStorage.getDefiningOp())
@@ -163,7 +215,7 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
// with no length parameters. Simply store the result in the result storage.
// at the return point.
if (!replacedStorage)
- rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
+ rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
return mlir::success();
}
@@ -181,7 +233,14 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
matchAndRewrite(fir::AddrOfOp addrOf,
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
- auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
+ mlir::FunctionType newFuncTy;
+ // TODO: This should be generalized for derived types, and it is
+ // architecture and OS dependent.
+ if (oldFuncTy.getNumResults() != 0 &&
+ fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
+ newFuncTy = getCPtrFunctionType(oldFuncTy);
+ else
+ newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
addrOf.getSymbol());
// Rather than converting all op a function pointer might transit through
@@ -263,6 +322,18 @@ class AbstractResultOnFuncOpt
// Convert function type itself if it has an abstract result.
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
if (hasAbstractResult(funcTy)) {
+ // TODO: This should be generalized for derived types, and it is
+ // architecture and OS dependent.
+ if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
+ func.setType(getCPtrFunctionType(funcTy));
+ patterns.insert<ReturnOpConversion>(context, mlir::Value{});
+ target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
+ [](mlir::func::ReturnOp ret) {
+ mlir::Type retTy = ret.getOperand(0).getType();
+ return !fir::isa_builtin_cptr_type(retTy);
+ });
+ return;
+ }
func.setType(getNewFunctionType(funcTy, shouldBoxResult));
if (!func.empty()) {
// Insert new argument.
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 92d803e4994ba..14c59a6569744 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -87,6 +87,26 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
// FUNC-BOX: return
}
+// FUNC-REF-LABEL: func @retcptr() -> i64
+// FUNC-BOX-LABEL: func @retcptr() -> i64
+func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {
+ %0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
+ %1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
+ return %1 : !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+
+ // FUNC-REF: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
+ // FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
+ // FUNC-REF: return %[[VAL]] : i64
+ // FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
+ // FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
+ // FUNC-BOX: return %[[VAL]] : i64
+}
+
+
// ------------------------ Test caller rewrite --------------------------------
// FUNC-REF-LABEL: func @call_arrayfunc() {
@@ -202,15 +222,26 @@ func.func @call_chararrayfunc() {
// FUNC-BOX-NOT: fir.save_result
}
-func.func private @rettcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> attributes {fir.bindc_name = "rettcptr"}
-
-// FUNC-REF-LABEL: func @_QPtest_return_cptr() {
-// FUNC-BOX-LABEL: func @_QPtest_return_cptr() {
+// FUNC-REF-LABEL: func @_QPtest_return_cptr
+// FUNC-BOX-LABEL: func @_QPtest_return_cptr
func.func @_QPtest_return_cptr() {
- // FUNC-REF: [[VAL:.*]] = fir.call @rettcptr() : () -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
- // FUNC-BOX: [[VAL:.*]] = fir.call @rettcptr() : () -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
- %1 = fir.call @rettcptr() : () -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ %0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ %1 = fir.call @retcptr() : () -> i64
+ %2 = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ %3 = fir.coordinate_of %0, %2 : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ fir.store %1 to %3 : !fir.ref<i64>
return
+
+ // FUNC-REF: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ // FUNC-REF: %[[VAL:.*]] = fir.call @retcptr() : () -> i64
+ // FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-REF: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
+ // FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ // FUNC-BOX: %[[VAL:.*]] = fir.call @retcptr() : () -> i64
+ // FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
}
// ------------------------ Test fir.address_of rewrite ------------------------
@@ -234,6 +265,29 @@ func.func @test_address_of() {
}
+// FUNC-REF-LABEL: func.func private @returns_null() -> i64
+// FUNC-BOX-LABEL: func.func private @returns_null() -> i64
+func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+
+// FUNC-REF-LABEL: func @test_address_of_cptr
+// FUNC-BOX-LABEL: func @test_address_of_cptr
+func.func @test_address_of_cptr() {
+ %0 = fir.address_of(@returns_null) : () -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ %1 = fir.convert %0 : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
+ fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> ()
+ return
+
+ // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
+ // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
+ // FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
+ // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
+ // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
+ // FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
+}
+
+
// ----------------------- Test indirect calls rewrite ------------------------
// FUNC-REF-LABEL: func @test_indirect_calls(
@@ -267,6 +321,33 @@ func.func @test_indirect_calls(%arg0: () -> ()) {
// FUNC-BOX-NOT: fir.save_result
}
+// FUNC-REF-LABEL: func @test_indirect_calls_return_cptr(
+// FUNC-REF-SAME: %[[ARG0:.*]]: () -> ())
+// FUNC-BOX-LABEL: func @test_indirect_calls_return_cptr(
+// FUNC-BOX-SAME: %[[ARG0:.*]]: () -> ())
+func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) {
+ %0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ %1 = fir.convert %arg0 : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ %2 = fir.call %1() : () -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ fir.save_result %2 to %0 : !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
+ return
+
+ // FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
+ // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
+ // FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
+ // FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
+ // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
+ // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
+ // FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
+ // FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ // FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
+}
+
// ----------------------- Test GlobalOp rewrite ------------------------
// This is needed to separate GlobalOp tests from FuncOp tests for FileCheck
More information about the flang-commits
mailing list