[flang-commits] [flang] [flang] fix C_PTR function result lowering (PR #100082)
via flang-commits
flang-commits at lists.llvm.org
Tue Jul 23 01:50:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: None (jeanPerier)
<details>
<summary>Changes</summary>
Functions returning C_PTR were lowered to functions returning intptr (i64 on 64-bit architectures). This caused conflicts when these functions were defined as returning !fir.ref<none>/llvm.ptr in other compiler-generated contexts (e.g., malloc).
Lower them to return !fir.ref<none>.
This should deal with https://github.com/llvm/llvm-project/issues/97325 and https://github.com/llvm/llvm-project/issues/98644.
---
Full diff: https://github.com/llvm/llvm-project/pull/100082.diff
3 Files Affected:
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+35-20)
- (modified) flang/lib/Optimizer/Transforms/AbstractResult.cpp (+56-52)
- (modified) flang/test/Fir/abstract-results.fir (+20-16)
``````````diff
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 2961df96b3cab..9599e1505303c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1541,21 +1541,45 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
zero);
}
+static std::pair<mlir::Value, mlir::Type>
+genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Type cptrTy) {
+ assert(mlir::isa<fir::RecordType>(cptrTy));
+ auto recTy = mlir::dyn_cast<fir::RecordType>(cptrTy);
+ assert(recTy.getTypeList().size() == 1);
+ auto addrFieldName = recTy.getTypeList()[0].first;
+ mlir::Type addrFieldTy = recTy.getTypeList()[0].second;
+ auto fieldIndexType = fir::FieldType::get(cptrTy.getContext());
+ mlir::Value addrFieldIndex = builder.create<fir::FieldIndexOp>(
+ loc, fieldIndexType, addrFieldName, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ return {addrFieldIndex, addrFieldTy};
+}
+
mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr,
mlir::Type ty) {
- assert(mlir::isa<fir::RecordType>(ty));
- auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
- assert(recTy.getTypeList().size() == 1);
- auto fieldName = recTy.getTypeList()[0].first;
- mlir::Type fieldTy = recTy.getTypeList()[0].second;
- auto fieldIndexType = fir::FieldType::get(ty.getContext());
- mlir::Value field =
- builder.create<fir::FieldIndexOp>(loc, fieldIndexType, fieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- return builder.create<fir::CoordinateOp>(loc, builder.getRefType(fieldTy),
- cPtr, field);
+ auto [addrFieldIndex, addrFieldTy] =
+ genCPtrOrCFunptrFieldIndex(builder, loc, ty);
+ return builder.create<fir::CoordinateOp>(loc, builder.getRefType(addrFieldTy),
+ cPtr, addrFieldIndex);
+}
+
+mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value cPtr) {
+ mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
+ if (fir::isa_ref_type(cPtr.getType())) {
+ mlir::Value cPtrAddr =
+ fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
+ return builder.create<fir::LoadOp>(loc, cPtrAddr);
+ }
+ auto [addrFieldIndex, addrFieldTy] =
+ genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy);
+ auto arrayAttr =
+ builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)});
+ return builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr, arrayAttr);
}
fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
@@ -1596,15 +1620,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
return fir::BoxValue(box, lbounds, explicitTypeParams);
}
-mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
- mlir::Location loc,
- mlir::Value cPtr) {
- mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
- mlir::Value cPtrAddr =
- fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
- return builder.create<fir::LoadOp>(loc, cPtrAddr);
-}
-
mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Type boxType) {
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 3906aa553cb34..ff37310224e85 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
/*resultTypes=*/{});
}
+static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
+ return fir::ReferenceType::get(mlir::NoneType::get(context));
+}
+
/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
/// Follow the ABI for interoperability with C.
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
- auto resultType = funcTy.getResult(0);
- assert(fir::isa_builtin_cptr_type(resultType));
- llvm::SmallVector<mlir::Type> outputTypes;
- auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
- outputTypes.emplace_back(recTy.getTypeList()[0].second);
+ assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
+ llvm::SmallVector<mlir::Type> outputTypes{
+ getVoidPtrType(funcTy.getContext())};
return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
outputTypes);
}
@@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
saveResult.getTypeparams());
llvm::SmallVector<mlir::Type> newResultTypes;
- // TODO: This should be generalized for derived types, and it is
- // architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
- Op newOp;
- if (isResultBuiltinCPtr) {
- auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType());
- newResultTypes.emplace_back(recTy.getTypeList()[0].second);
- }
+ if (isResultBuiltinCPtr)
+ newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
+ Op newOp;
// fir::CallOp specific handling.
if constexpr (std::is_same_v<Op, fir::CallOp>) {
if (op.getCallee()) {
@@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
FirOpBuilder builder(rewriter, module);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
- rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
+ builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
}
op->dropAllReferences();
rewriter.eraseOp(op);
@@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
mlir::PatternRewriter &rewriter) const override {
auto loc = ret.getLoc();
rewriter.setInsertionPoint(ret);
- auto returnedValue = ret.getOperand(0);
- bool replacedStorage = false;
- if (auto *op = returnedValue.getDefiningOp())
- if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
- auto resultStorage = load.getMemref();
- // The result alloca may be behind a fir.declare, if any.
- if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
- resultStorage.getDefiningOp()))
- resultStorage = declare.getMemref();
- // TODO: This should be generalized for derived types, and it is
- // architecture and OS dependent.
- if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
- rewriter.eraseOp(load);
- auto module = ret->getParentOfType<mlir::ModuleOp>();
- FirOpBuilder builder(rewriter, module);
- mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
- builder, loc, resultStorage, returnedValue.getType());
- mlir::Value retValue = rewriter.create<fir::LoadOp>(
- loc, fir::unwrapRefType(retAddr.getType()), retAddr);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
- ret, mlir::ValueRange{retValue});
- return mlir::success();
- }
- resultStorage.replaceAllUsesWith(newArg);
- replacedStorage = true;
- if (auto *alloc = resultStorage.getDefiningOp())
- if (alloc->use_empty())
- rewriter.eraseOp(alloc);
+ mlir::Value resultValue = ret.getOperand(0);
+ fir::LoadOp resultLoad;
+ mlir::Value resultStorage;
+ // Identify result local storage.
+ if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
+ resultLoad = load;
+ resultStorage = load.getMemref();
+ // The result alloca may be behind a fir.declare, if any.
+ if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
+ resultStorage = declare.getMemref();
+ }
+ // Replace old local storage with new storage argument, unless
+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
+ // type is updated to return void* (no new argument is passed).
+ if (fir::isa_builtin_cptr_type(resultValue.getType())) {
+ auto module = ret->getParentOfType<mlir::ModuleOp>();
+ FirOpBuilder builder(rewriter, module);
+ mlir::Value cptr = resultValue;
+ if (resultLoad) {
+ // Replace whole derived type load by component load.
+ cptr = resultLoad.getMemref();
+ rewriter.setInsertionPoint(resultLoad);
}
- // The result storage may have been optimized out by a memory to
- // register pass, this is possible for fir.box results, or fir.record
- // with no length parameters. Simply store the result in the result storage.
- // at the return point.
- if (!replacedStorage)
- rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
+ mlir::Value newResultValue =
+ fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
+ newResultValue = builder.createConvert(
+ loc, getVoidPtrType(ret.getContext()), newResultValue);
+ rewriter.setInsertionPoint(ret);
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
+ ret, mlir::ValueRange{newResultValue});
+ } else if (resultStorage) {
+ resultStorage.replaceAllUsesWith(newArg);
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
+ } else {
+ // The result storage may have been optimized out by a memory to
+ // register pass, this is possible for fir.box results, or fir.record
+ // with no length parameters. Simply store the result in the result
+ // storage. at the return point.
+ rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
+ rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
+ }
+ // Delete result old local storage if unused.
+ if (resultStorage)
+ if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
+ if (alloc->use_empty())
+ rewriter.eraseOp(alloc);
return mlir::success();
}
@@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
mlir::FunctionType newFuncTy;
- // TODO: This should be generalized for derived types, and it is
- // architecture and OS dependent.
if (oldFuncTy.getNumResults() != 0 &&
fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
newFuncTy = getCPtrFunctionType(oldFuncTy);
@@ -298,8 +304,6 @@ class AbstractResultOpt
// Convert function type itself if it has an abstract result.
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasAbstractResult(funcTy)) {
- // TODO: This should be generalized for derived types, and it is
- // architecture and OS dependent.
if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
func.setType(getCPtrFunctionType(funcTy));
patterns.insert<ReturnOpConversion>(context, mlir::Value{});
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 82f1cd33073fd..93e63dc657f0c 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
// FUNC-BOX: return
}
-// FUNC-REF-LABEL: func @retcptr() -> i64
-// FUNC-BOX-LABEL: func @retcptr() -> i64
+// FUNC-REF-LABEL: func @retcptr() -> !fir.ref<none>
+// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref<none>
func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {
%0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
%1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
@@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres
// FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
- // FUNC-REF: return %[[VAL]] : i64
+ // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
+ // FUNC-REF: return %[[CAST]] : !fir.ref<none>
// FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
// FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
- // FUNC-BOX: return %[[VAL]] : i64
+ // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
+ // FUNC-BOX: return %[[CAST]] : !fir.ref<none>
}
// FUNC-REF-LABEL: func private @arrayfunc_callee_declare(
@@ -311,8 +313,8 @@ func.func @test_address_of() {
}
-// FUNC-REF-LABEL: func.func private @returns_null() -> i64
-// FUNC-BOX-LABEL: func.func private @returns_null() -> i64
+// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref<none>
+// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref<none>
func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF-LABEL: func @test_address_of_cptr
@@ -323,12 +325,12 @@ func.func @test_address_of_cptr() {
fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> ()
return
- // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
- // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
+ // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
- // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
- // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
+ // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
+ // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
}
@@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) {
// FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
- // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
- // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
+ // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
+ // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
- // FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
+ // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
+ // FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
- // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
- // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
+ // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
+ // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
- // FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
+ // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
+ // FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
}
// ----------------------- Test GlobalOp rewrite ------------------------
``````````
</details>
https://github.com/llvm/llvm-project/pull/100082
More information about the flang-commits
mailing list