[flang-commits] [flang] ef59dbe - [flang][cuda] Lower c_devptr value arguments in bind(c) like c_ptr (#199316)

via flang-commits flang-commits at lists.llvm.org
Tue May 26 13:02:43 PDT 2026


Author: Valentin Clement (バレンタイン クレメン)
Date: 2026-05-26T13:02:38-07:00
New Revision: ef59dbea76d82f008e5314c47e574193dbc4d403

URL: https://github.com/llvm/llvm-project/commit/ef59dbea76d82f008e5314c47e574193dbc4d403
DIFF: https://github.com/llvm/llvm-project/commit/ef59dbea76d82f008e5314c47e574193dbc4d403.diff

LOG: [flang][cuda] Lower c_devptr value arguments in bind(c) like c_ptr (#199316)

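Treat `type(c_devptr), value` arguments in BIND(C) interfaces like
`type(c_ptr), value`: pass the raw address stored in the nested c_ptr
component, instead of passing the outer c_devptr record with the
derived-type ABI. This keeps call signatures consistent for CUDA Fortran
generic specifics that share a C binding label and avoids argument
misclassification at the x86_64 register/stack boundary.

To support this, C_DEVPTR is now recognized by the existing C_PTR/C_FUNPTR
helpers (`fir::isa_builtin_cptr_type`, `IsBuiltinCPtr`,
`genCPtrOrCFunptrAddr`), which replaces the separate `genCDevPtrAddr` helper.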

Added: 
    flang/test/HLFIR/c_devptr_byvalue.cuf

Modified: 
    flang/include/flang/Lower/CallInterface.h
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/include/flang/Optimizer/Dialect/FIRType.h
    flang/lib/Evaluate/tools.cpp
    flang/lib/Lower/CallInterface.cpp
    flang/lib/Lower/ConvertCall.cpp
    flang/lib/Lower/ConvertConstant.cpp
    flang/lib/Lower/ConvertVariable.cpp
    flang/lib/Optimizer/Builder/FIRBuilder.cpp
    flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h
index 9ccfb684510a1..348987cb76c5e 100644
--- a/flang/include/flang/Lower/CallInterface.h
+++ b/flang/include/flang/Lower/CallInterface.h
@@ -494,7 +494,7 @@ getDummyProcedurePointerType(const Fortran::semantics::Symbol &dummyProcPtr,
 mlir::Type getUntypedBoxProcType(mlir::MLIRContext *context);
 
 /// Return true if \p ty is "!fir.ref<i64>", which is the interface for
-/// type(C_PTR/C_FUNPTR) passed by value.
+/// type(C_PTR/C_FUNPTR/C_DEVPTR) passed by value.
 bool isCPtrArgByValueType(mlir::Type ty);
 
 /// Is it required to pass \p proc as a tuple<function address, result length> ?

diff  --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index dc99174d6b993..0f4a0b63755a0 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -905,17 +905,10 @@ mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
 mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value value, mlir::Value zero);
 
-/// The type(C_PTR/C_FUNPTR) is defined as the derived type with only one
-/// component of integer 64, and the component is the C address. Get the C
-/// address.
+/// Get the C address from a type(C_PTR/C_FUNPTR/C_DEVPTR) entity.
 mlir::Value genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder, mlir::Location loc,
                                  mlir::Value cPtr, mlir::Type ty);
 
-/// The type(C_DEVPTR) is defined as the derived type with only one
-/// component of C_PTR type. Get the C address from the C_PTR component.
-mlir::Value genCDevPtrAddr(fir::FirOpBuilder &builder, mlir::Location loc,
-                           mlir::Value cDevPtr, mlir::Type ty);
-
 /// Get the C address value.
 mlir::Value genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
                                   mlir::Location loc, mlir::Value cPtr);

diff  --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index f67cc3ed5db34..daca6d209f553 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -157,11 +157,12 @@ inline bool conformsWithPassByRef(mlir::Type t) {
 /// Is `t` a derived (record) type?
 inline bool isa_derived(mlir::Type t) { return mlir::isa<fir::RecordType>(t); }
 
-/// Is `t` type(c_ptr) or type(c_funptr)?
+/// Is `t` type(c_ptr), type(c_funptr), or type(c_devptr)?
 inline bool isa_builtin_cptr_type(mlir::Type t) {
   if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(t))
     return recTy.getName().ends_with("T__builtin_c_ptr") ||
-           recTy.getName().ends_with("T__builtin_c_funptr");
+           recTy.getName().ends_with("T__builtin_c_funptr") ||
+           recTy.getName().ends_with("T__builtin_c_devptr");
   return false;
 }
 

diff  --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 3edaa5befd27b..82dcd1e795f49 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -2531,7 +2531,7 @@ bool IsBuiltinDerivedType(const DerivedTypeSpec *derived, const char *name) {
 bool IsBuiltinCPtr(const Symbol &symbol) {
   if (const DeclTypeSpec *declType = symbol.GetType()) {
     if (const DerivedTypeSpec *derived = declType->AsDerived()) {
-      return IsIsoCType(derived);
+      return IsIsoCType(derived) || IsBuiltinDerivedType(derived, "c_devptr");
     }
   }
   return false;

diff  --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index e9059581c690a..3fcf314faefb6 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1212,6 +1212,9 @@ class Fortran::lower::CallInterfaceImpl {
           if (isBuiltinCptrType) {
             auto recTy = mlir::dyn_cast<fir::RecordType>(type);
             mlir::Type fieldTy = recTy.getTypeList()[0].second;
+            if (fir::isa_builtin_cdevptr_type(type))
+              fieldTy =
+                  mlir::cast<fir::RecordType>(fieldTy).getTypeList()[0].second;
             passType = fir::ReferenceType::get(fieldTy);
           } else {
             passType = type;

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index e6c89122bde23..d01f54357d9fb 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -79,15 +79,15 @@ static fir::ExtendedValue toExtendedValue(mlir::Location loc, mlir::Value base,
   return base;
 }
 
-/// Lower a type(C_PTR/C_FUNPTR) argument with VALUE attribute into a
+/// Lower a type(C_PTR/C_FUNPTR/C_DEVPTR) argument with VALUE attribute into a
 /// reference. A C pointer can correspond to a Fortran dummy argument of type
 /// C_PTR with the VALUE attribute. (see 18.3.6 note 3).
 static mlir::Value genRecordCPtrValueArg(fir::FirOpBuilder &builder,
                                          mlir::Location loc, mlir::Value rec,
-                                         mlir::Type ty) {
-  mlir::Value cAddr = fir::factory::genCPtrOrCFunptrAddr(builder, loc, rec, ty);
-  mlir::Value cVal = fir::LoadOp::create(builder, loc, cAddr);
-  return builder.createConvert(loc, cAddr.getType(), cVal);
+                                         mlir::Type) {
+  mlir::Value cVal = fir::factory::genCPtrOrCFunptrValue(builder, loc, rec);
+  return builder.createConvert(loc, fir::ReferenceType::get(cVal.getType()),
+                               cVal);
 }
 
 // Find the argument that corresponds to the host associations.
@@ -1752,7 +1752,7 @@ void prepareUserCallArguments(
 
       mlir::Type eleTy = value.getFortranElementType();
       if (fir::isa_builtin_cptr_type(eleTy)) {
-        // Pass-by-value argument of type(C_PTR/C_FUNPTR).
+        // Pass-by-value argument of type(C_PTR/C_FUNPTR/C_DEVPTR).
         // Load the __address component and pass it by value.
         if (value.isValue()) {
           auto associate = hlfir::genAssociateExpr(loc, builder, value, eleTy,

diff  --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp
index b063099497386..ca24965c51668 100644
--- a/flang/lib/Lower/ConvertConstant.cpp
+++ b/flang/lib/Lower/ConvertConstant.cpp
@@ -446,16 +446,15 @@ static mlir::Value genStructureComponentInit(
   if (Fortran::lower::isDerivedTypeWithLenParameters(sym))
     TODO(loc, "component with length parameters in structure constructor");
 
-  // Special handling for scalar c_ptr/c_funptr constants. The array constant
-  // must fall through to genConstantValue() below.
+  // Special handling for scalar c_ptr/c_funptr/c_devptr constants. The array
+  // constant must fall through to genConstantValue() below.
   if (Fortran::semantics::IsBuiltinCPtr(sym) && sym.Rank() == 0 &&
       (Fortran::evaluate::GetLastSymbol(expr) ||
        Fortran::evaluate::IsNullPointer(&expr))) {
-    // Builtin c_ptr and c_funptr have special handling because designators
-    // and NULL() are handled as initial values for them as an extension
-    // (otherwise only c_ptr_null/c_funptr_null are allowed and these are
-    // replaced by structure constructors by semantics, so GetLastSymbol
-    // returns nothing).
+    // Builtin C pointer types have special handling because designators and
+    // NULL() are handled as initial values for them as an extension (otherwise
+    // only the named null constants are allowed and these are replaced by
+    // structure constructors by semantics, so GetLastSymbol returns nothing).
 
     // The Ev::Expr is an initializer that is a pointer target (e.g., 'x' or
     // NULL()) that must be inserted into an intermediate cptr record value's
@@ -468,20 +467,36 @@ static mlir::Value genStructureComponentInit(
             mlir::isa<mlir::FunctionType>(addr.getType())) &&
            "expect reference type for address field");
     assert(fir::isa_derived(componentTy) &&
-           "expect C_PTR, C_FUNPTR to be a record");
-    auto cPtrRecTy = mlir::cast<fir::RecordType>(componentTy);
+           "expect C_PTR, C_FUNPTR, C_DEVPTR to be a record");
+    auto componentRecTy = mlir::cast<fir::RecordType>(componentTy);
+    mlir::Type cPtrTy = componentTy;
+    if (fir::isa_builtin_cdevptr_type(componentTy)) {
+      assert(componentRecTy.getTypeList().size() == 1);
+      cPtrTy = componentRecTy.getTypeList()[0].second;
+    }
+    auto cPtrRecTy = mlir::cast<fir::RecordType>(cPtrTy);
     llvm::StringRef addrFieldName = Fortran::lower::builtin::cptrFieldName;
     mlir::Type addrFieldTy = cPtrRecTy.getType(addrFieldName);
-    auto addrField = fir::FieldIndexOp::create(
-        builder, loc, fieldTy, addrFieldName, componentTy,
-        /*typeParams=*/mlir::ValueRange{});
+    auto addrField =
+        fir::FieldIndexOp::create(builder, loc, fieldTy, addrFieldName, cPtrTy,
+                                  /*typeParams=*/mlir::ValueRange{});
     mlir::Value castAddr = builder.createConvert(loc, addrFieldTy, addr);
-    auto undef = fir::UndefOp::create(builder, loc, componentTy);
-    addr = fir::InsertValueOp::create(
-        builder, loc, componentTy, undef, castAddr,
+    auto undef = fir::UndefOp::create(builder, loc, cPtrTy);
+    mlir::Value componentValue = fir::InsertValueOp::create(
+        builder, loc, cPtrTy, undef, castAddr,
         builder.getArrayAttr(addrField.getAttributes()));
+    if (fir::isa_builtin_cdevptr_type(componentTy)) {
+      auto cptrFieldName = componentRecTy.getTypeList()[0].first;
+      auto cptrField = fir::FieldIndexOp::create(
+          builder, loc, fieldTy, cptrFieldName, componentTy,
+          /*typeParams=*/mlir::ValueRange{});
+      auto cdevptrUndef = fir::UndefOp::create(builder, loc, componentTy);
+      componentValue = fir::InsertValueOp::create(
+          builder, loc, componentTy, cdevptrUndef, componentValue,
+          builder.getArrayAttr(cptrField.getAttributes()));
+    }
     res =
-        fir::InsertValueOp::create(builder, loc, recTy, res, addr,
+        fir::InsertValueOp::create(builder, loc, recTy, res, componentValue,
                                    builder.getArrayAttr(field.getAttributes()));
     return res;
   }

diff  --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 4e7b091d26186..f4efdfef7b2d2 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -588,15 +588,14 @@ fir::GlobalOp Fortran::lower::defineGlobal(
     if (details && details->init()) {
       auto sym{*details->init()};
       if (sym) // Has a procedure target.
-        createGlobalInitialization(
-            builder, global, [&](fir::FirOpBuilder &b) {
-              Fortran::lower::StatementContext stmtCtx(
-                  /*cleanupProhibited=*/true);
-              auto box{Fortran::lower::convertProcedureDesignatorInitialTarget(
-                  converter, loc, *sym)};
-              auto castTo{builder.createConvert(loc, symTy, box)};
-              fir::HasValueOp::create(b, loc, castTo);
-            });
+        createGlobalInitialization(builder, global, [&](fir::FirOpBuilder &b) {
+          Fortran::lower::StatementContext stmtCtx(
+              /*cleanupProhibited=*/true);
+          auto box{Fortran::lower::convertProcedureDesignatorInitialTarget(
+              converter, loc, *sym)};
+          auto castTo{builder.createConvert(loc, symTy, box)};
+          fir::HasValueOp::create(b, loc, castTo);
+        });
       else { // Has NULL() target.
         createGlobalInitialization(builder, global, [&](fir::FirOpBuilder &b) {
           auto box{fir::factory::createNullBoxProc(b, loc, symTy)};
@@ -2506,15 +2505,15 @@ void Fortran::lower::mapSymbolAttributes(
   if (!addr) {
     if (arg) {
       mlir::Type argType = arg.getType();
+      mlir::Type symType = converter.genType(sym);
       const bool isCptrByVal = Fortran::semantics::IsBuiltinCPtr(sym) &&
                                Fortran::lower::isCPtrArgByValueType(argType);
       if (isCptrByVal || !fir::conformsWithPassByRef(argType)) {
         // Dummy argument passed in register. Place the value in memory at that
         // point since lowering expect symbols to be mapped to memory addresses.
-        mlir::Type symType = converter.genType(sym);
         addr = fir::AllocaOp::create(builder, loc, symType);
         if (isCptrByVal) {
-          // Place the void* address into the CPTR address component.
+          // Place the void* address into the pointer address component.
           mlir::Value addrComponent =
               fir::factory::genCPtrOrCFunptrAddr(builder, loc, addr, symType);
           builder.createStoreWithConvert(loc, arg, addrComponent);

diff  --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 4ce5c6955b5c6..050f88b2f8ae0 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1802,31 +1802,20 @@ mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
                                                mlir::Location loc,
                                                mlir::Value cPtr,
                                                mlir::Type ty) {
+  if (fir::isa_builtin_cdevptr_type(ty)) {
+    auto [cptrFieldIndex, cptrFieldTy] =
+        genCPtrOrCFunptrFieldIndex(builder, loc, ty);
+    auto cptrCoord = fir::CoordinateOp::create(
+        builder, loc, builder.getRefType(cptrFieldTy), cPtr, cptrFieldIndex);
+    return fir::factory::genCPtrOrCFunptrAddr(builder, loc, cptrCoord,
+                                              cptrFieldTy);
+  }
   auto [addrFieldIndex, addrFieldTy] =
       genCPtrOrCFunptrFieldIndex(builder, loc, ty);
   return fir::CoordinateOp::create(
       builder, loc, builder.getRefType(addrFieldTy), cPtr, addrFieldIndex);
 }
 
-mlir::Value fir::factory::genCDevPtrAddr(fir::FirOpBuilder &builder,
-                                         mlir::Location loc,
-                                         mlir::Value cDevPtr, mlir::Type ty) {
-  auto recTy = mlir::cast<fir::RecordType>(ty);
-  assert(recTy.getTypeList().size() == 1);
-  auto cptrFieldName = recTy.getTypeList()[0].first;
-  mlir::Type cptrFieldTy = recTy.getTypeList()[0].second;
-  auto fieldIndexType = fir::FieldType::get(ty.getContext());
-  mlir::Value cptrFieldIndex = fir::FieldIndexOp::create(
-      builder, loc, fieldIndexType, cptrFieldName, recTy,
-      /*typeParams=*/mlir::ValueRange{});
-  auto cptrCoord = fir::CoordinateOp::create(
-      builder, loc, builder.getRefType(cptrFieldTy), cDevPtr, cptrFieldIndex);
-  auto [addrFieldIndex, addrFieldTy] =
-      genCPtrOrCFunptrFieldIndex(builder, loc, cptrFieldTy);
-  return fir::CoordinateOp::create(
-      builder, loc, builder.getRefType(addrFieldTy), cptrCoord, addrFieldIndex);
-}
-
 mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
                                                 mlir::Location loc,
                                                 mlir::Value cPtr) {

diff  --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index e3be05bad1051..82f27c0fee37f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -3108,15 +3108,12 @@ static void clocDeviceArgRewrite(fir::ExtendedValue arg) {
 static fir::ExtendedValue
 genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc,
                  mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args,
-                 bool isFunc = false, bool isDevLoc = false) {
+                 bool isFunc = false) {
   assert(args.size() == 1);
   clocDeviceArgRewrite(args[0]);
   mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
-  mlir::Value resAddr;
-  if (isDevLoc)
-    resAddr = fir::factory::genCDevPtrAddr(builder, loc, res, resultType);
-  else
-    resAddr = fir::factory::genCPtrOrCFunptrAddr(builder, loc, res, resultType);
+  mlir::Value resAddr =
+      fir::factory::genCPtrOrCFunptrAddr(builder, loc, res, resultType);
   assert(fir::isa_box_type(fir::getBase(args[0]).getType()) &&
          "argument must have been lowered to box type");
   mlir::Value argAddr = getAddrFromBox(builder, loc, args[0], isFunc);
@@ -3178,8 +3175,7 @@ IntrinsicLibrary::genCAssociatedCPtr(mlir::Type resultType,
 fir::ExtendedValue
 IntrinsicLibrary::genCDevLoc(mlir::Type resultType,
                              llvm::ArrayRef<fir::ExtendedValue> args) {
-  return genCLocOrCFunLoc(builder, loc, resultType, args, /*isFunc=*/false,
-                          /*isDevLoc=*/true);
+  return genCLocOrCFunLoc(builder, loc, resultType, args);
 }
 
 // C_F_POINTER

diff  --git a/flang/test/HLFIR/c_devptr_byvalue.cuf b/flang/test/HLFIR/c_devptr_byvalue.cuf
new file mode 100644
index 0000000000000..46229df4610a9
--- /dev/null
+++ b/flang/test/HLFIR/c_devptr_byvalue.cuf
@@ -0,0 +1,22 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! CHECK-LABEL:   func.func @_QPtest_c_devptr(
+! CHECK-SAME:                                %[[ARG0:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>> {fir.bindc_name = "cdevptr"}) {
+! CHECK:           %[[DSCOPE:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK:           %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %[[DSCOPE]] arg {{[0-9]+}} {uniq_name = "_QFtest_c_devptrEcdevptr"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.dscope) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>)
+! CHECK:           %[[CPTR:.*]] = fir.coordinate_of %[[DECL]]#0, cptr : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
+! CHECK:           %[[ADDR:.*]] = fir.coordinate_of %[[CPTR]], __address : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) -> !fir.ref<i64>
+! CHECK:           %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
+! CHECK:           %[[ARG:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<i64>
+! CHECK:           fir.call @get_expected_f(%[[ARG]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
+subroutine test_c_devptr(cdevptr)
+  use __fortran_builtins, only : c_devptr => __builtin_c_devptr
+  interface
+     subroutine get_expected_f(src) bind(c)
+       use __fortran_builtins, only : c_devptr => __builtin_c_devptr
+       type(c_devptr), value :: src
+     end subroutine get_expected_f
+  end interface
+  type(c_devptr) :: cdevptr
+  call get_expected_f(cdevptr)
+end


        


More information about the flang-commits mailing list