[flang-commits] [flang] [flang][cuda] Lower device/managed/unified allocation to cuda ops (PR #90623)

via flang-commits flang-commits at lists.llvm.org
Tue Apr 30 08:56:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Lower local allocations of CUDA device, managed, and unified variables to fir.cuda_alloc. Add a matching fir.cuda_free in the function context finalization.

@<!-- -->vzakhari For some reason PR #<!-- -->90526 was closed when I merged PR #<!-- -->90525, so I am reopening it as a new PR.

---
Full diff: https://github.com/llvm/llvm-project/pull/90623.diff


6 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+7) 
- (modified) flang/include/flang/Semantics/tools.h (+17) 
- (modified) flang/lib/Lower/ConvertVariable.cpp (+29) 
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+14-11) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+15) 
- (modified) flang/test/Lower/CUDA/cuda-data-attribute.cuf (+25) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index e4c954159f71be..0d650f830b64e0 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -708,6 +708,13 @@ mlir::Value createNullBoxProc(fir::FirOpBuilder &builder, mlir::Location loc,
 
 /// Set internal linkage attribute on a function.
 void setInternalLinkage(mlir::func::FuncOp);
+
+llvm::SmallVector<mlir::Value>
+elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape);
+
+llvm::SmallVector<mlir::Value>
+elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index da10969ebc7021..c9eb5bc857ac01 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -222,6 +222,23 @@ inline bool HasCUDAAttr(const Symbol &sym) {
   return false;
 }
 
+inline bool NeedCUDAAlloc(const Symbol &sym) {
+  bool inDeviceSubprogram{IsCUDADeviceContext(&sym.owner())};
+  if (const auto *details{
+          sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
+    if (details->cudaDataAttr() &&
+        (*details->cudaDataAttr() == common::CUDADataAttr::Device ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Managed ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
+      // Descriptor is allocated on host when in host context.
+      if (Fortran::semantics::IsAllocatable(sym))
+        return inDeviceSubprogram;
+      return true;
+    }
+  }
+  return false;
+}
+
 const Scope *FindCUDADeviceContext(const Scope *);
 std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *);
 
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 413563fe95ca38..f31fbab41028c1 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -693,6 +693,22 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
   if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
     return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
 
+  if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
+    fir::CUDADataAttributeAttr cudaAttr =
+        Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
+                                                         ultimateSymbol);
+    llvm::SmallVector<mlir::Value> indices;
+    llvm::SmallVector<mlir::Value> elidedShape =
+        fir::factory::elideExtentsAlreadyInType(ty, shape);
+    llvm::SmallVector<mlir::Value> elidedLenParams =
+        fir::factory::elideLengthsAlreadyInType(ty, lenParams);
+    auto idxTy = builder.getIndexType();
+    for (mlir::Value sh : elidedShape)
+      indices.push_back(builder.createConvert(loc, idxTy, sh));
+    return builder.create<fir::CUDAAllocOp>(loc, ty, nm, symNm, cudaAttr,
+                                            lenParams, indices);
+  }
+
   // Let the builder do all the heavy lifting.
   if (!Fortran::semantics::IsProcedurePointer(ultimateSymbol))
     return builder.allocateLocal(loc, ty, nm, symNm, shape, lenParams, isTarg);
@@ -927,6 +943,19 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
       });
     }
   }
+  if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
+    auto *builder = &converter.getFirOpBuilder();
+    mlir::Location loc = converter.getCurrentLocation();
+    fir::ExtendedValue exv =
+        converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
+    auto *sym = &var.getSymbol();
+    converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
+      fir::CUDADataAttributeAttr cudaAttr =
+          Fortran::lower::translateSymbolCUDADataAttribute(
+              builder->getContext(), *sym);
+      builder->create<fir::CUDAFreeOp>(loc, fir::getBase(exv), cudaAttr);
+    });
+  }
 }
 
 //===----------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index a6da387637264d..bd018d7f015b86 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -176,8 +176,9 @@ mlir::Value fir::FirOpBuilder::createRealConstant(mlir::Location loc,
   llvm_unreachable("should use builtin floating-point type");
 }
 
-static llvm::SmallVector<mlir::Value>
-elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
+llvm::SmallVector<mlir::Value>
+fir::factory::elideExtentsAlreadyInType(mlir::Type type,
+                                        mlir::ValueRange shape) {
   auto arrTy = mlir::dyn_cast<fir::SequenceType>(type);
   if (shape.empty() || !arrTy)
     return {};
@@ -191,8 +192,9 @@ elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
   return dynamicShape;
 }
 
-static llvm::SmallVector<mlir::Value>
-elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams) {
+llvm::SmallVector<mlir::Value>
+fir::factory::elideLengthsAlreadyInType(mlir::Type type,
+                                        mlir::ValueRange lenParams) {
   if (lenParams.empty())
     return {};
   if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(type))
@@ -211,9 +213,9 @@ mlir::Value fir::FirOpBuilder::allocateLocal(
   // Convert the shape extents to `index`, as needed.
   llvm::SmallVector<mlir::Value> indices;
   llvm::SmallVector<mlir::Value> elidedShape =
-      elideExtentsAlreadyInType(ty, shape);
+      fir::factory::elideExtentsAlreadyInType(ty, shape);
   llvm::SmallVector<mlir::Value> elidedLenParams =
-      elideLengthsAlreadyInType(ty, lenParams);
+      fir::factory::elideLengthsAlreadyInType(ty, lenParams);
   auto idxTy = getIndexType();
   for (mlir::Value sh : elidedShape)
     indices.push_back(createConvert(loc, idxTy, sh));
@@ -283,9 +285,9 @@ fir::FirOpBuilder::createTemporary(mlir::Location loc, mlir::Type type,
                                    mlir::ValueRange lenParams,
                                    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
   llvm::SmallVector<mlir::Value> dynamicShape =
-      elideExtentsAlreadyInType(type, shape);
+      fir::factory::elideExtentsAlreadyInType(type, shape);
   llvm::SmallVector<mlir::Value> dynamicLength =
-      elideLengthsAlreadyInType(type, lenParams);
+      fir::factory::elideLengthsAlreadyInType(type, lenParams);
   InsertPoint insPt;
   const bool hoistAlloc = dynamicShape.empty() && dynamicLength.empty();
   if (hoistAlloc) {
@@ -306,9 +308,9 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
     mlir::ValueRange shape, mlir::ValueRange lenParams,
     llvm::ArrayRef<mlir::NamedAttribute> attrs) {
   llvm::SmallVector<mlir::Value> dynamicShape =
-      elideExtentsAlreadyInType(type, shape);
+      fir::factory::elideExtentsAlreadyInType(type, shape);
   llvm::SmallVector<mlir::Value> dynamicLength =
-      elideLengthsAlreadyInType(type, lenParams);
+      fir::factory::elideLengthsAlreadyInType(type, lenParams);
 
   assert(!mlir::isa<fir::ReferenceType>(type) && "cannot be a reference");
   return create<fir::AllocMemOp>(loc, type, /*unique_name=*/llvm::StringRef{},
@@ -660,7 +662,8 @@ mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, mlir::Type boxType,
   mlir::Type valueOrSequenceType = fir::unwrapPassByRefType(boxType);
   return create<fir::EmboxOp>(
       loc, boxType, addr, shape, slice,
-      elideLengthsAlreadyInType(valueOrSequenceType, lengths), tdesc);
+      fir::factory::elideLengthsAlreadyInType(valueOrSequenceType, lengths),
+      tdesc);
 }
 
 void fir::FirOpBuilder::dumpFunc() { getFunction().dump(); }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6773d0adced0ce..5e6c18af2dd0f9 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4033,6 +4033,21 @@ mlir::LogicalResult fir::CUDADeallocateOp::verify() {
   return mlir::success();
 }
 
+void fir::CUDAAllocOp::build(
+    mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType,
+    llvm::StringRef uniqName, llvm::StringRef bindcName,
+    fir::CUDADataAttributeAttr cudaAttr, mlir::ValueRange typeparams,
+    mlir::ValueRange shape, llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+  mlir::StringAttr nameAttr =
+      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
+  mlir::StringAttr bindcAttr =
+      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
+  build(builder, result, wrapAllocaResultType(inType),
+        mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
+        cudaAttr);
+  result.addAttributes(attributes);
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-attribute.cuf b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
index 937c981bddd368..083a3cacc02062 100644
--- a/flang/test/Lower/CUDA/cuda-data-attribute.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
@@ -62,4 +62,29 @@ end subroutine
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "du", fir.cuda_attr = #fir.cuda<unified>})
 ! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFdummy_arg_unifiedEdu"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 
+subroutine cuda_alloc_free(n)
+  integer :: n
+  real, device :: a(10)
+  integer, unified :: u
+  real, managed :: b(n)
+end
+
+! CHECK-LABEL: func.func @_QMcuda_varPcuda_alloc_free
+! CHECK: %[[ALLOC_A:.*]] = fir.cuda_alloc !fir.array<10xf32> {bindc_name = "a", cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+! CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOC_A]](%[[SHAPE]]) {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+
+! CHECK: %[[ALLOC_U:.*]] = fir.cuda_alloc i32 {bindc_name = "u", cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} -> !fir.ref<i32>
+! CHECK: %[[DECL_U:.*]]:2 = hlfir.declare %[[ALLOC_U]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+! CHECK: %[[ALLOC_B:.*]] = fir.cuda_alloc !fir.array<?xf32>, %{{.*}} : index {bindc_name = "b", cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} -> !fir.ref<!fir.array<?xf32>>
+! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECL_B:.*]]:2 = hlfir.declare %[[ALLOC_B]](%[[SHAPE]]) {cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+
+! CHECK: fir.cuda_free %[[DECL_B]]#1 : !fir.ref<!fir.array<?xf32>> {cuda_attr = #fir.cuda<managed>}
+! CHECK: fir.cuda_free %[[DECL_U]]#1 : !fir.ref<i32> {cuda_attr = #fir.cuda<unified>}
+! CHECK: fir.cuda_free %[[DECL_A]]#1 : !fir.ref<!fir.array<10xf32>> {cuda_attr = #fir.cuda<device>}
+
 end module
+
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/90623


More information about the flang-commits mailing list