[flang-commits] [flang] [flang][cuda] Allocate descriptor in managed memory when emboxing device memory (PR #120485)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Dec 18 14:03:06 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/120485

When emboxing memory that comes from CUFMemAlloc, the descriptor must be allocated in managed memory, because it might be passed to a kernel.

>From 387177b019d0fcfb709a6f4ac2ead9e349a7714e Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 18 Dec 2024 14:01:11 -0800
Subject: [PATCH] [flang][cuda] Allocate descriptor in managed memory when
 emboxing device memory

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 212 +++++++++++++-----------
 flang/test/Fir/CUDA/cuda-code-gen.mlir  |  31 +++-
 2 files changed, 147 insertions(+), 96 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 082f2b15512b8b..a479ea5e865ef9 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -23,7 +23,9 @@
 #include "flang/Optimizer/Support/InternalNames.h"
 #include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Support/Utils.h"
+#include "flang/Optimizer/Transforms/CUFCommon.h"
 #include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
 #include "flang/Runtime/allocator-registry-consts.h"
 #include "flang/Runtime/descriptor-consts.h"
 #include "flang/Semantics/runtime-type-info.h"
@@ -1135,6 +1137,116 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
   return result;
 }
 
+/// Return a pointer to a NUL-terminated global string holding the file name
+/// of \p loc. The linkonce constant global is reused if \p mod already has
+/// it (in FIR or LLVM form) and is otherwise created at the end of \p mod.
+/// A null pointer is returned when \p loc is not a FileLineColLoc.
+static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
+    auto fn = flc.getFilename().str() + '\0';
+    std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
+
+    // The global may already exist, either still in FIR form or already
+    // converted to the LLVM dialect.
+    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
+      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
+    }
+
+    // Create the global at the end of the module, then restore the caller's
+    // insertion point before taking the global's address.
+    auto crtInsPt = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
+        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
+    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
+        globalName, mlir::Attribute());
+
+    mlir::Region &region = globalOp.getInitializerRegion();
+    mlir::Block *block = rewriter.createBlock(&region);
+    rewriter.setInsertionPoint(block, block->begin());
+    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
+        loc, arrayTy, rewriter.getStringAttr(fn));
+    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
+    rewriter.restoreInsertionPoint(crtInsPt);
+    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
+                                                    globalOp.getName());
+  }
+  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
+}
+
+/// Return an i32 constant holding the line number of \p loc, or 0 when \p loc
+/// is not a FileLineColLoc.
+static mlir::Value genSourceLine(mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter) {
+  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   flc.getLine());
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
+}
+
+/// Allocate storage for a descriptor of type \p boxTy by calling the CUDA
+/// Fortran runtime entry point CUFAllocDesciptor with the descriptor size in
+/// bytes and the source file/line of \p loc; the call result is returned.
+/// The runtime function is declared in \p mod if it is not already present.
+/// NOTE: "Desciptor" is the runtime entry's actual spelling; keep it as is.
+static mlir::Value
+genCUFAllocDescriptor(mlir::Location loc,
+                      mlir::ConversionPatternRewriter &rewriter,
+                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
+                      const fir::LLVMTypeConverter &typeConverter) {
+  std::optional<mlir::DataLayout> dl =
+      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+  if (!dl) {
+    mlir::emitError(mod.getLoc(),
+                    "module operation must carry a data layout attribute "
+                    "to generate llvm IR from FIR");
+    // The descriptor size below is computed with the layout: fall back to the
+    // module's default layout rather than dereferencing an empty optional.
+    dl.emplace(mod);
+  }
+
+  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
+  mlir::Value sourceLine = genSourceLine(loc, rewriter);
+
+  mlir::MLIRContext *ctx = mod.getContext();
+
+  // Runtime signature: (intptr sizeInBytes, ptr sourceFile, i32 sourceLine)
+  // -> ptr, with intptr as wide as an address-space-0 pointer.
+  mlir::LLVM::LLVMPointerType llvmPointerType =
+      mlir::LLVM::LLVMPointerType::get(ctx);
+  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
+  mlir::Type llvmIntPtrType =
+      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
+  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
+      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
+
+  // Declare the runtime function unless the module already has it, either
+  // as an llvm.func or as a not-yet-converted func.func.
+  auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
+      RTNAME_STRING(CUFAllocDesciptor));
+  auto funcFunc =
+      mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
+  if (!llvmFunc && !funcFunc)
+    mlir::OpBuilder::atBlockEnd(mod.getBody())
+        .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
+                                        fctTy);
+
+  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
+  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+  mlir::Value sizeInBytes =
+      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
+  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
+  return rewriter
+      .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
+                                  args)
+      .getResult();
+}
+
 /// Common base class for embox to descriptor conversion.
 template <typename OP>
 struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
@@ -1548,15 +1637,31 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
   mlir::Value
   placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
                                mlir::Location loc, mlir::Type boxTy,
-                               mlir::Value boxValue) const {
+                               mlir::Value boxValue,
+                               bool needDeviceAllocation = false) const {
     if (isInGlobalOp(rewriter))
       return boxValue;
     mlir::Type llvmBoxTy = boxValue.getType();
-    auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy,
-                                                     defaultAlign, rewriter);
-    auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
+    mlir::Value storage;
+    if (needDeviceAllocation) {
+      // The descriptor of device memory may be passed to a kernel, so it is
+      // allocated in managed memory through the CUF runtime, not with alloca.
+      // Take the module from the insertion point: boxValue could be a block
+      // argument, which has no defining operation.
+      auto mod = rewriter.getInsertionBlock()
+                     ->getParentOp()
+                     ->getParentOfType<mlir::ModuleOp>();
+      // A non-box type here is a programming error; mlir::cast asserts on it.
+      auto baseBoxTy = mlir::cast<fir::BaseBoxType>(boxTy);
+      storage =
+          genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy());
+    } else {
+      storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign,
+                                                   rewriter);
+    }
+    auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage);
     this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
-    return alloca;
+    return storage;
   }
 };
 
@@ -1608,6 +1706,21 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
   }
 };
 
+/// Return true if \p val is produced by a call to the CUFMemAlloc runtime
+/// entry point, looking through any fir.convert operations in between. Used to
+/// decide whether an embox descriptor must be allocated through the runtime.
+static bool isDeviceAllocation(mlir::Value val) {
+  while (auto convertOp =
+             mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
+    val = convertOp.getValue();
+  auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp());
+  if (!callOp || !callOp.getCallee())
+    return false;
+  // Prefix match on the runtime entry name (e.g. _FortranACUFMemAlloc).
+  return callOp.getCallee()->getRootReference().getValue().starts_with(
+      RTNAME_STRING(CUFMemAlloc));
+}
+
 /// Create a generic box on a memory reference.
 struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
   using EmboxCommonConversion::EmboxCommonConversion;
@@ -1791,9 +1901,10 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
     dest = insertBaseAddress(rewriter, loc, dest, base);
     if (fir::isDerivedTypeWithLenParams(boxTy))
       TODO(loc, "fir.embox codegen of derived with length parameters");
 
-    mlir::Value result =
-        placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
+    // Descriptors of memory obtained from CUFMemAlloc go to managed memory.
+    mlir::Value result = placeInMemoryIfNotGlobalInit(
+        rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
     rewriter.replaceOp(xbox, result);
     return mlir::success();
   }
@@ -2971,93 +3080,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
   }
 };
 
-static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
-                                 mlir::ConversionPatternRewriter &rewriter) {
-  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
-  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
-    auto fn = flc.getFilename().str() + '\0';
-    std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
-
-    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
-      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
-    } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
-      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
-    }
-
-    auto crtInsPt = rewriter.saveInsertionPoint();
-    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
-    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
-        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
-    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
-        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
-        globalName, mlir::Attribute());
-
-    mlir::Region &region = globalOp.getInitializerRegion();
-    mlir::Block *block = rewriter.createBlock(&region);
-    rewriter.setInsertionPoint(block, block->begin());
-    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
-        loc, arrayTy, rewriter.getStringAttr(fn));
-    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
-    rewriter.restoreInsertionPoint(crtInsPt);
-    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
-                                                    globalOp.getName());
-  }
-  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
-}
-
-static mlir::Value genSourceLine(mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter) {
-  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
-    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                   flc.getLine());
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
-}
-
-static mlir::Value
-genCUFAllocDescriptor(mlir::Location loc,
-                      mlir::ConversionPatternRewriter &rewriter,
-                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
-                      const fir::LLVMTypeConverter &typeConverter) {
-  std::optional<mlir::DataLayout> dl =
-      fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
-  if (!dl)
-    mlir::emitError(mod.getLoc(),
-                    "module operation must carry a data layout attribute "
-                    "to generate llvm IR from FIR");
-
-  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
-  mlir::Value sourceLine = genSourceLine(loc, rewriter);
-
-  mlir::MLIRContext *ctx = mod.getContext();
-
-  mlir::LLVM::LLVMPointerType llvmPointerType =
-      mlir::LLVM::LLVMPointerType::get(ctx);
-  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
-  mlir::Type llvmIntPtrType =
-      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
-  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
-      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
-
-  auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
-      RTNAME_STRING(CUFAllocDesciptor));
-  auto funcFunc =
-      mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
-  if (!llvmFunc && !funcFunc)
-    mlir::OpBuilder::atBlockEnd(mod.getBody())
-        .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
-                                        fctTy);
-
-  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
-  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
-  mlir::Value sizeInBytes =
-      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
-  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
-  return rewriter
-      .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
-                                  args)
-      .getResult();
-}
-
 /// `fir.load` --> `llvm.load`
 struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
   using FIROpConversion::FIROpConversion;
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index 55e473ef2549e3..a34c2770c5f6c5 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -1,7 +1,6 @@
 // RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
-
   func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
     %c0 = arith.constant 0 : index
     %0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
@@ -27,3 +26,36 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
   }
   func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
 }
+
+// -----
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
+  func.func @_QQmain() attributes {fir.bindc_name = "test"} {
+    %c10 = arith.constant 10 : index
+    %c20 = arith.constant 20 : index
+    %0 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
+    %c4 = arith.constant 4 : index
+    %c200 = arith.constant 200 : index
+    %1 = arith.muli %c200, %c4 : index
+    %c6_i32 = arith.constant 6 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %2 = fir.convert %1 : (index) -> i64
+    %3 = fir.convert %0 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+    %4 = fir.call @_FortranACUFMemAlloc(%2, %c0_i32, %3, %c6_i32) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+    %5 = fir.convert %4 : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10x20xi32>>
+    %6 = fircg.ext_embox %5(%c10, %c20) : (!fir.ref<!fir.array<10x20xi32>>, index, index) -> !fir.box<!fir.array<10x20xi32>>
+    return
+  }
+  fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
+    %0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
+    fir.has_value %0 : !fir.char<1,11>
+  }
+  func.func private @_FortranACUFMemAlloc(i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8> attributes {fir.runtime}
+}
+
+// Emboxing memory returned by CUFMemAlloc must allocate the descriptor through
+// the runtime and store the descriptor into that allocation.
+// CHECK-LABEL: llvm.func @_QQmain()
+// CHECK: llvm.call @_FortranACUFMemAlloc
+// CHECK: %[[DESC:.*]] = llvm.call @_FortranACUFAllocDesciptor
+// CHECK: llvm.store %{{.*}}, %[[DESC]]



More information about the flang-commits mailing list