[flang-commits] [flang] [flang][cuda] Correctly allocate memory for descriptor load (PR #120164)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Dec 16 19:10:39 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/120164

>From 1cee7492b76ea633e7db3a5b1554fa349fb671ec Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 16 Dec 2024 15:59:51 -0800
Subject: [PATCH 1/2] [flang][cuda] Correctly allocate memory for descriptor
 load

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 108 +++++++++++++++++++++++-
 flang/test/Fir/CUDA/cuda-code-gen.mlir  |  29 +++++++
 2 files changed, 135 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Fir/CUDA/cuda-code-gen.mlir

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 5345d64c330f06..723b4ecd8f582a 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -23,6 +23,7 @@
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Support/Utils.h"
+#include "flang/Runtime/CUDA/descriptor.h"
 #include "flang/Runtime/allocator-registry-consts.h"
 #include "flang/Runtime/descriptor-consts.h"
 #include "flang/Semantics/runtime-type-info.h"
@@ -63,6 +64,8 @@ namespace fir {
 
 #define DEBUG_TYPE "flang-codegen"
 
+using namespace Fortran::runtime::cuda;
+
 // TODO: This should really be recovered from the specified target.
 static constexpr unsigned defaultAlign = 8;
 
@@ -2970,6 +2973,93 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
   }
 };
 
+/// Address of a global with \p loc's file name; null if not a FileLineColLoc.
+static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
+    auto fn = flc.getFilename().str() + '\0'; // NUL-terminated C string.
+    std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
+    // Reuse an existing FIR or LLVM global holding this file name.
+    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName))
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName))
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    // Otherwise, append a constant linkonce global to the end of the module.
+    auto crtInsPt = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+    auto arrayTy =
+        mlir::LLVM::LLVMArrayType::get(rewriter.getI8Type(), fn.size());
+    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
+        globalName, mlir::Attribute());
+    // The initializer region returns the file name as an i8 array constant.
+    mlir::Region &region = globalOp.getInitializerRegion();
+    mlir::Block *block = rewriter.createBlock(&region);
+    rewriter.setInsertionPoint(block, block->begin());
+    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
+        loc, arrayTy, rewriter.getStringAttr(fn));
+    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
+    rewriter.restoreInsertionPoint(crtInsPt);
+    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
+                                                    globalOp.getName());
+  }
+  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
+}
+
+/// i32 constant holding \p loc's line number, or 0 if not a FileLineColLoc.
+static mlir::Value genSourceLine(mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc);
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(), flc ? flc.getLine() : 0u);
+}
+
+/// Emit a CUFAllocDesciptor runtime call returning new storage for a \p boxTy
+/// descriptor, declaring the function if needed. Null if no data layout.
+static mlir::Value
+genCUFAllocDescriptor(mlir::Location loc,
+                      mlir::ConversionPatternRewriter &rewriter,
+                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
+                      const fir::LLVMTypeConverter &typeConverter) {
+  std::optional<mlir::DataLayout> dl =
+      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+  if (!dl) {
+    mlir::emitError(mod.getLoc(),
+                    "module operation must carry a data layout attribute "
+                    "to generate llvm IR from FIR");
+    return {}; // Do not dereference the empty layout below.
+  }
+
+  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
+  mlir::Value sourceLine = genSourceLine(loc, rewriter);
+
+  mlir::MLIRContext *ctx = mod.getContext();
+  auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(ctx);
+  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
+  mlir::Type llvmIntPtrType =
+      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
+  // Runtime signature: ptr CUFAllocDesciptor(intptr size, ptr file, i32 line).
+  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
+      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
+
+  // Declare the runtime function unless a func or LLVM declaration exists.
+  llvm::StringRef fctName = RTNAME_STRING(CUFAllocDesciptor);
+  if (!mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(fctName) &&
+      !mod.lookupSymbol<mlir::func::FuncOp>(fctName))
+    mlir::OpBuilder::atBlockEnd(mod.getBody())
+        .create<mlir::LLVM::LLVMFuncOp>(loc, fctName, fctTy);
+
+  // Pass the byte size of the LLVM struct the descriptor lowers to.
+  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
+  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+  mlir::Value sizeInBytes =
+      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
+  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
+  return rewriter.create<mlir::LLVM::CallOp>(loc, fctTy, fctName, args)
+      .getResult();
+}
+
 /// `fir.load` --> `llvm.load`
 struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
   using FIROpConversion::FIROpConversion;
@@ -2986,9 +3076,23 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       // loading a fir.ref<fir.box> is implemented as taking a snapshot of the
       // descriptor value into a new descriptor temp.
       auto inputBoxStorage = adaptor.getOperands()[0];
+      mlir::Value newBoxStorage;
       mlir::Location loc = load.getLoc();
-      auto newBoxStorage =
-          genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
+      if (auto callOp = mlir::dyn_cast_or_null<mlir::LLVM::CallOp>(
+              inputBoxStorage.getDefiningOp())) {
+        if (callOp.getCallee() &&
+            (*callOp.getCallee())
+                .starts_with(RTNAME_STRING(CUFAllocDesciptor))) {
+          // CUDA Fortran local descriptor are allocated in managed memory. So
+          // new storage must be allocated the same way.
+          auto mod = load->getParentOfType<mlir::ModuleOp>();
+          newBoxStorage =
+              genCUFAllocDescriptor(loc, rewriter, mod, boxTy, lowerTy());
+        }
+      }
+      if (!newBoxStorage)
+        newBoxStorage = genAllocaAndAddrCastWithType(loc, llvmLoadTy,
+                                                     defaultAlign, rewriter);
 
       TypePair boxTypePair{boxTy, llvmLoadTy};
       mlir::Value boxSize =
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
new file mode 100644
index 00000000000000..55e473ef2549e3
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -0,0 +1,29 @@
+// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+  func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
+    %c0 = arith.constant 0 : index
+    %0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
+    %c4_i32 = arith.constant 4 : i32
+    %c48 = arith.constant 48 : index
+    %1 = fir.convert %c48 : (index) -> i64
+    %2 = fir.convert %0 : (!fir.ref<!fir.char<1,8>>) -> !fir.ref<i8>
+    %3 = fir.call @_FortranACUFAllocDesciptor(%1, %2, %c4_i32) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
+    %4 = fir.convert %3 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    %5 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+    %6 = fircg.ext_embox %5(%c0) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, index) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+    fir.store %6 to %4 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    %8 = fir.load %3 : !fir.ref<!fir.box<none>>
+    return
+  }
+  // Loading the CUFAllocDesciptor result (%3) must call the runtime again.
+  // CHECK-LABEL: llvm.func @_QQmain()
+  // CHECK-COUNT-2: llvm.call @_FortranACUFAllocDesciptor 
+
+  fir.global linkonce @_QQclX3C737464696E3E00 constant : !fir.char<1,8> {
+    %0 = fir.string_lit "<stdin>\00"(8) : !fir.char<1,8>
+    fir.has_value %0 : !fir.char<1,8>
+  }
+  func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
+}

>From db97343b4b8fb767c96c74536da18d489c42c998 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 16 Dec 2024 19:10:25 -0800
Subject: [PATCH 2/2] Remove useless code

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 723b4ecd8f582a..082f2b15512b8b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -64,8 +64,6 @@ namespace fir {
 
 #define DEBUG_TYPE "flang-codegen"
 
-using namespace Fortran::runtime::cuda;
-
 // TODO: This should really be recovered from the specified target.
 static constexpr unsigned defaultAlign = 8;
 



More information about the flang-commits mailing list