[flang-commits] [flang] ef59dbe - [flang][cuda] Lower c_devptr value arguments in bind(c) like c_ptr (#199316)
via flang-commits
flang-commits at lists.llvm.org
Tue May 26 13:02:43 PDT 2026
Author: Valentin Clement (バレンタイン クレメン)
Date: 2026-05-26T13:02:38-07:00
New Revision: ef59dbea76d82f008e5314c47e574193dbc4d403
URL: https://github.com/llvm/llvm-project/commit/ef59dbea76d82f008e5314c47e574193dbc4d403
DIFF: https://github.com/llvm/llvm-project/commit/ef59dbea76d82f008e5314c47e574193dbc4d403.diff
LOG: [flang][cuda] Lower c_devptr value arguments in bind(c) like c_ptr (#199316)
Treat `type(c_devptr), value` arguments in BIND(C) interfaces like
`type(c_ptr), value`: pass the raw address value held in the nested C_PTR
component instead of following the ABI of the outer derived type. This keeps
call signatures consistent for CUDA Fortran generic specifics that share a C
binding label, and avoids argument misclassification at the x86_64
register/stack boundary.
Added:
flang/test/HLFIR/c_devptr_byvalue.cuf
Modified:
flang/include/flang/Lower/CallInterface.h
flang/include/flang/Optimizer/Builder/FIRBuilder.h
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Evaluate/tools.cpp
flang/lib/Lower/CallInterface.cpp
flang/lib/Lower/ConvertCall.cpp
flang/lib/Lower/ConvertConstant.cpp
flang/lib/Lower/ConvertVariable.cpp
flang/lib/Optimizer/Builder/FIRBuilder.cpp
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h
index 9ccfb684510a1..348987cb76c5e 100644
--- a/flang/include/flang/Lower/CallInterface.h
+++ b/flang/include/flang/Lower/CallInterface.h
@@ -494,7 +494,7 @@ getDummyProcedurePointerType(const Fortran::semantics::Symbol &dummyProcPtr,
mlir::Type getUntypedBoxProcType(mlir::MLIRContext *context);
/// Return true if \p ty is "!fir.ref<i64>", which is the interface for
-/// type(C_PTR/C_FUNPTR) passed by value.
+/// type(C_PTR/C_FUNPTR/C_DEVPTR) passed by value.
bool isCPtrArgByValueType(mlir::Type ty);
/// Is it required to pass \p proc as a tuple<function address, result length> ?
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index dc99174d6b993..0f4a0b63755a0 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -905,17 +905,10 @@ mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value value, mlir::Value zero);
-/// The type(C_PTR/C_FUNPTR) is defined as the derived type with only one
-/// component of integer 64, and the component is the C address. Get the C
-/// address.
+/// Get the C address from a type(C_PTR/C_FUNPTR/C_DEVPTR) entity.
mlir::Value genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value cPtr, mlir::Type ty);
-/// The type(C_DEVPTR) is defined as the derived type with only one
-/// component of C_PTR type. Get the C address from the C_PTR component.
-mlir::Value genCDevPtrAddr(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value cDevPtr, mlir::Type ty);
-
/// Get the C address value.
mlir::Value genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value cPtr);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index f67cc3ed5db34..daca6d209f553 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -157,11 +157,12 @@ inline bool conformsWithPassByRef(mlir::Type t) {
/// Is `t` a derived (record) type?
inline bool isa_derived(mlir::Type t) { return mlir::isa<fir::RecordType>(t); }
-/// Is `t` type(c_ptr) or type(c_funptr)?
+/// Is `t` type(c_ptr), type(c_funptr), or type(c_devptr)?
inline bool isa_builtin_cptr_type(mlir::Type t) {
if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(t))
return recTy.getName().ends_with("T__builtin_c_ptr") ||
- recTy.getName().ends_with("T__builtin_c_funptr");
+ recTy.getName().ends_with("T__builtin_c_funptr") ||
+ recTy.getName().ends_with("T__builtin_c_devptr");
return false;
}
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 3edaa5befd27b..82dcd1e795f49 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -2531,7 +2531,7 @@ bool IsBuiltinDerivedType(const DerivedTypeSpec *derived, const char *name) {
bool IsBuiltinCPtr(const Symbol &symbol) {
if (const DeclTypeSpec *declType = symbol.GetType()) {
if (const DerivedTypeSpec *derived = declType->AsDerived()) {
- return IsIsoCType(derived);
+ return IsIsoCType(derived) || IsBuiltinDerivedType(derived, "c_devptr");
}
}
return false;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index e9059581c690a..3fcf314faefb6 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1212,6 +1212,9 @@ class Fortran::lower::CallInterfaceImpl {
if (isBuiltinCptrType) {
auto recTy = mlir::dyn_cast<fir::RecordType>(type);
mlir::Type fieldTy = recTy.getTypeList()[0].second;
+ if (fir::isa_builtin_cdevptr_type(type))
+ fieldTy =
+ mlir::cast<fir::RecordType>(fieldTy).getTypeList()[0].second;
passType = fir::ReferenceType::get(fieldTy);
} else {
passType = type;
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index e6c89122bde23..d01f54357d9fb 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -79,15 +79,15 @@ static fir::ExtendedValue toExtendedValue(mlir::Location loc, mlir::Value base,
return base;
}
-/// Lower a type(C_PTR/C_FUNPTR) argument with VALUE attribute into a
+/// Lower a type(C_PTR/C_FUNPTR/C_DEVPTR) argument with VALUE attribute into a
/// reference. A C pointer can correspond to a Fortran dummy argument of type
/// C_PTR with the VALUE attribute. (see 18.3.6 note 3).
static mlir::Value genRecordCPtrValueArg(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value rec,
- mlir::Type ty) {
- mlir::Value cAddr = fir::factory::genCPtrOrCFunptrAddr(builder, loc, rec, ty);
- mlir::Value cVal = fir::LoadOp::create(builder, loc, cAddr);
- return builder.createConvert(loc, cAddr.getType(), cVal);
+ mlir::Type) {
+ mlir::Value cVal = fir::factory::genCPtrOrCFunptrValue(builder, loc, rec);
+ return builder.createConvert(loc, fir::ReferenceType::get(cVal.getType()),
+ cVal);
}
// Find the argument that corresponds to the host associations.
@@ -1752,7 +1752,7 @@ void prepareUserCallArguments(
mlir::Type eleTy = value.getFortranElementType();
if (fir::isa_builtin_cptr_type(eleTy)) {
- // Pass-by-value argument of type(C_PTR/C_FUNPTR).
+ // Pass-by-value argument of type(C_PTR/C_FUNPTR/C_DEVPTR).
// Load the __address component and pass it by value.
if (value.isValue()) {
auto associate = hlfir::genAssociateExpr(loc, builder, value, eleTy,
diff --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp
index b063099497386..ca24965c51668 100644
--- a/flang/lib/Lower/ConvertConstant.cpp
+++ b/flang/lib/Lower/ConvertConstant.cpp
@@ -446,16 +446,15 @@ static mlir::Value genStructureComponentInit(
if (Fortran::lower::isDerivedTypeWithLenParameters(sym))
TODO(loc, "component with length parameters in structure constructor");
- // Special handling for scalar c_ptr/c_funptr constants. The array constant
- // must fall through to genConstantValue() below.
+ // Special handling for scalar c_ptr/c_funptr/c_devptr constants. The array
+ // constant must fall through to genConstantValue() below.
if (Fortran::semantics::IsBuiltinCPtr(sym) && sym.Rank() == 0 &&
(Fortran::evaluate::GetLastSymbol(expr) ||
Fortran::evaluate::IsNullPointer(&expr))) {
- // Builtin c_ptr and c_funptr have special handling because designators
- // and NULL() are handled as initial values for them as an extension
- // (otherwise only c_ptr_null/c_funptr_null are allowed and these are
- // replaced by structure constructors by semantics, so GetLastSymbol
- // returns nothing).
+ // Builtin C pointer types have special handling because designators and
+ // NULL() are handled as initial values for them as an extension (otherwise
+ // only the named null constants are allowed and these are replaced by
+ // structure constructors by semantics, so GetLastSymbol returns nothing).
// The Ev::Expr is an initializer that is a pointer target (e.g., 'x' or
// NULL()) that must be inserted into an intermediate cptr record value's
@@ -468,20 +467,36 @@ static mlir::Value genStructureComponentInit(
mlir::isa<mlir::FunctionType>(addr.getType())) &&
"expect reference type for address field");
assert(fir::isa_derived(componentTy) &&
- "expect C_PTR, C_FUNPTR to be a record");
- auto cPtrRecTy = mlir::cast<fir::RecordType>(componentTy);
+ "expect C_PTR, C_FUNPTR, C_DEVPTR to be a record");
+ auto componentRecTy = mlir::cast<fir::RecordType>(componentTy);
+ mlir::Type cPtrTy = componentTy;
+ if (fir::isa_builtin_cdevptr_type(componentTy)) {
+ assert(componentRecTy.getTypeList().size() == 1);
+ cPtrTy = componentRecTy.getTypeList()[0].second;
+ }
+ auto cPtrRecTy = mlir::cast<fir::RecordType>(cPtrTy);
llvm::StringRef addrFieldName = Fortran::lower::builtin::cptrFieldName;
mlir::Type addrFieldTy = cPtrRecTy.getType(addrFieldName);
- auto addrField = fir::FieldIndexOp::create(
- builder, loc, fieldTy, addrFieldName, componentTy,
- /*typeParams=*/mlir::ValueRange{});
+ auto addrField =
+ fir::FieldIndexOp::create(builder, loc, fieldTy, addrFieldName, cPtrTy,
+ /*typeParams=*/mlir::ValueRange{});
mlir::Value castAddr = builder.createConvert(loc, addrFieldTy, addr);
- auto undef = fir::UndefOp::create(builder, loc, componentTy);
- addr = fir::InsertValueOp::create(
- builder, loc, componentTy, undef, castAddr,
+ auto undef = fir::UndefOp::create(builder, loc, cPtrTy);
+ mlir::Value componentValue = fir::InsertValueOp::create(
+ builder, loc, cPtrTy, undef, castAddr,
builder.getArrayAttr(addrField.getAttributes()));
+ if (fir::isa_builtin_cdevptr_type(componentTy)) {
+ auto cptrFieldName = componentRecTy.getTypeList()[0].first;
+ auto cptrField = fir::FieldIndexOp::create(
+ builder, loc, fieldTy, cptrFieldName, componentTy,
+ /*typeParams=*/mlir::ValueRange{});
+ auto cdevptrUndef = fir::UndefOp::create(builder, loc, componentTy);
+ componentValue = fir::InsertValueOp::create(
+ builder, loc, componentTy, cdevptrUndef, componentValue,
+ builder.getArrayAttr(cptrField.getAttributes()));
+ }
res =
- fir::InsertValueOp::create(builder, loc, recTy, res, addr,
+ fir::InsertValueOp::create(builder, loc, recTy, res, componentValue,
builder.getArrayAttr(field.getAttributes()));
return res;
}
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 4e7b091d26186..f4efdfef7b2d2 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -588,15 +588,14 @@ fir::GlobalOp Fortran::lower::defineGlobal(
if (details && details->init()) {
auto sym{*details->init()};
if (sym) // Has a procedure target.
- createGlobalInitialization(
- builder, global, [&](fir::FirOpBuilder &b) {
- Fortran::lower::StatementContext stmtCtx(
- /*cleanupProhibited=*/true);
- auto box{Fortran::lower::convertProcedureDesignatorInitialTarget(
- converter, loc, *sym)};
- auto castTo{builder.createConvert(loc, symTy, box)};
- fir::HasValueOp::create(b, loc, castTo);
- });
+ createGlobalInitialization(builder, global, [&](fir::FirOpBuilder &b) {
+ Fortran::lower::StatementContext stmtCtx(
+ /*cleanupProhibited=*/true);
+ auto box{Fortran::lower::convertProcedureDesignatorInitialTarget(
+ converter, loc, *sym)};
+ auto castTo{builder.createConvert(loc, symTy, box)};
+ fir::HasValueOp::create(b, loc, castTo);
+ });
else { // Has NULL() target.
createGlobalInitialization(builder, global, [&](fir::FirOpBuilder &b) {
auto box{fir::factory::createNullBoxProc(b, loc, symTy)};
@@ -2506,15 +2505,15 @@ void Fortran::lower::mapSymbolAttributes(
if (!addr) {
if (arg) {
mlir::Type argType = arg.getType();
+ mlir::Type symType = converter.genType(sym);
const bool isCptrByVal = Fortran::semantics::IsBuiltinCPtr(sym) &&
Fortran::lower::isCPtrArgByValueType(argType);
if (isCptrByVal || !fir::conformsWithPassByRef(argType)) {
// Dummy argument passed in register. Place the value in memory at that
// point since lowering expect symbols to be mapped to memory addresses.
- mlir::Type symType = converter.genType(sym);
addr = fir::AllocaOp::create(builder, loc, symType);
if (isCptrByVal) {
- // Place the void* address into the CPTR address component.
+ // Place the void* address into the pointer address component.
mlir::Value addrComponent =
fir::factory::genCPtrOrCFunptrAddr(builder, loc, addr, symType);
builder.createStoreWithConvert(loc, arg, addrComponent);
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 4ce5c6955b5c6..050f88b2f8ae0 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1802,31 +1802,20 @@ mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr,
mlir::Type ty) {
+ if (fir::isa_builtin_cdevptr_type(ty)) {
+ auto [cptrFieldIndex, cptrFieldTy] =
+ genCPtrOrCFunptrFieldIndex(builder, loc, ty);
+ auto cptrCoord = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(cptrFieldTy), cPtr, cptrFieldIndex);
+ return fir::factory::genCPtrOrCFunptrAddr(builder, loc, cptrCoord,
+ cptrFieldTy);
+ }
auto [addrFieldIndex, addrFieldTy] =
genCPtrOrCFunptrFieldIndex(builder, loc, ty);
return fir::CoordinateOp::create(
builder, loc, builder.getRefType(addrFieldTy), cPtr, addrFieldIndex);
}
-mlir::Value fir::factory::genCDevPtrAddr(fir::FirOpBuilder &builder,
- mlir::Location loc,
- mlir::Value cDevPtr, mlir::Type ty) {
- auto recTy = mlir::cast<fir::RecordType>(ty);
- assert(recTy.getTypeList().size() == 1);
- auto cptrFieldName = recTy.getTypeList()[0].first;
- mlir::Type cptrFieldTy = recTy.getTypeList()[0].second;
- auto fieldIndexType = fir::FieldType::get(ty.getContext());
- mlir::Value cptrFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, cptrFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- auto cptrCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(cptrFieldTy), cDevPtr, cptrFieldIndex);
- auto [addrFieldIndex, addrFieldTy] =
- genCPtrOrCFunptrFieldIndex(builder, loc, cptrFieldTy);
- return fir::CoordinateOp::create(
- builder, loc, builder.getRefType(addrFieldTy), cptrCoord, addrFieldIndex);
-}
-
mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr) {
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index e3be05bad1051..82f27c0fee37f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -3108,15 +3108,12 @@ static void clocDeviceArgRewrite(fir::ExtendedValue arg) {
static fir::ExtendedValue
genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args,
- bool isFunc = false, bool isDevLoc = false) {
+ bool isFunc = false) {
assert(args.size() == 1);
clocDeviceArgRewrite(args[0]);
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
- mlir::Value resAddr;
- if (isDevLoc)
- resAddr = fir::factory::genCDevPtrAddr(builder, loc, res, resultType);
- else
- resAddr = fir::factory::genCPtrOrCFunptrAddr(builder, loc, res, resultType);
+ mlir::Value resAddr =
+ fir::factory::genCPtrOrCFunptrAddr(builder, loc, res, resultType);
assert(fir::isa_box_type(fir::getBase(args[0]).getType()) &&
"argument must have been lowered to box type");
mlir::Value argAddr = getAddrFromBox(builder, loc, args[0], isFunc);
@@ -3178,8 +3175,7 @@ IntrinsicLibrary::genCAssociatedCPtr(mlir::Type resultType,
fir::ExtendedValue
IntrinsicLibrary::genCDevLoc(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
- return genCLocOrCFunLoc(builder, loc, resultType, args, /*isFunc=*/false,
- /*isDevLoc=*/true);
+ return genCLocOrCFunLoc(builder, loc, resultType, args);
}
// C_F_POINTER
diff --git a/flang/test/HLFIR/c_devptr_byvalue.cuf b/flang/test/HLFIR/c_devptr_byvalue.cuf
new file mode 100644
index 0000000000000..46229df4610a9
--- /dev/null
+++ b/flang/test/HLFIR/c_devptr_byvalue.cuf
@@ -0,0 +1,22 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPtest_c_devptr(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>> {fir.bindc_name = "cdevptr"}) {
+! CHECK: %[[DSCOPE:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %[[DSCOPE]] arg {{[0-9]+}} {uniq_name = "_QFtest_c_devptrEcdevptr"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.dscope) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>)
+! CHECK: %[[CPTR:.*]] = fir.coordinate_of %[[DECL]]#0, cptr : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
+! CHECK: %[[ADDR:.*]] = fir.coordinate_of %[[CPTR]], __address : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) -> !fir.ref<i64>
+! CHECK: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
+! CHECK: %[[ARG:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<i64>
+! CHECK: fir.call @get_expected_f(%[[ARG]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
+subroutine test_c_devptr(cdevptr)
+ use __fortran_builtins, only : c_devptr => __builtin_c_devptr
+ interface
+ subroutine get_expected_f(src) bind(c)
+ use __fortran_builtins, only : c_devptr => __builtin_c_devptr
+ type(c_devptr), value :: src
+ end subroutine get_expected_f
+ end interface
+ type(c_devptr) :: cdevptr
+ call get_expected_f(cdevptr)
+end
More information about the flang-commits
mailing list