[llvm-branch-commits] [flang] [flang][cuda] Convert cuf.sync_descriptor to runtime call (PR #121524)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 2 14:05:39 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Convert the `cuf.sync_descriptor` op to a call to a new runtime entry point, `CUFSyncGlobalDescriptor`.

---
Full diff: https://github.com/llvm/llvm-project/pull/121524.diff


4 Files Affected:

- (modified) flang/include/flang/Runtime/CUDA/descriptor.h (+4) 
- (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+41-1) 
- (modified) flang/runtime/CUDA/descriptor.cpp (+7) 
- (added) flang/test/Fir/CUDA/cuda-sync-desc.mlir (+20) 


``````````diff
diff --git a/flang/include/flang/Runtime/CUDA/descriptor.h b/flang/include/flang/Runtime/CUDA/descriptor.h
index 55878aaac57fb3..0ee7feca10e44c 100644
--- a/flang/include/flang/Runtime/CUDA/descriptor.h
+++ b/flang/include/flang/Runtime/CUDA/descriptor.h
@@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)(
 void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
     const char *sourceFile = nullptr, int sourceLine = 0);
 
+/// Get the device address registered with the \p hostPtr and sync the descriptor at that address from the host one.
+void RTDECL(CUFSyncGlobalDescriptor)(
+    void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
+
 } // extern "C"
 
 } // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index fb0ef246546444..f08f9e412b8857 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -788,6 +788,45 @@ struct CUFLaunchOpConversion
   const mlir::SymbolTable &symTab;
 };
 
+struct CUFSyncDescriptorOpConversion
+    : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  CUFSyncDescriptorOpConversion(mlir::MLIRContext *context,
+                                const mlir::SymbolTable &symTab)
+      : OpRewritePattern(context), symTab{symTab} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::SyncDescriptorOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+
+    auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
+    if (!globalOp)
+      return mlir::failure();
+
+    auto hostAddr = builder.create<fir::AddrOfOp>(
+        loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
+    mlir::func::FuncOp callee =
+        fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
+                                                                       builder);
+    auto fTy = callee.getFunctionType();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+    mlir::Value sourceLine =
+        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+        builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
+    builder.create<fir::CallOp>(loc, callee, args);
+    op.erase();
+    return mlir::success();
+  }
+
+private:
+  const mlir::SymbolTable &symTab;
+};
+
 class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
 public:
   void runOnOperation() override {
@@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns(
                   CUFFreeOpConversion>(patterns.getContext());
   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                                &dl, &converter);
-  patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>(
+      patterns.getContext(), symtab);
 }
 
 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
diff --git a/flang/runtime/CUDA/descriptor.cpp b/flang/runtime/CUDA/descriptor.cpp
index 391c47e84241d4..947eeb66aa3d6c 100644
--- a/flang/runtime/CUDA/descriptor.cpp
+++ b/flang/runtime/CUDA/descriptor.cpp
@@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
       (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
 }
 
+void RTDEF(CUFSyncGlobalDescriptor)(
+    void *hostPtr, const char *sourceFile, int sourceLine) {
+  void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
+  RTNAME(CUFDescriptorSync)
+  ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
+}
+
 RT_EXT_API_GROUP_END
 }
 } // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-sync-desc.mlir b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
new file mode 100644
index 00000000000000..20b317f34a7f26
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
@@ -0,0 +1,20 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git at github.com:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+  fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+    %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+    %c0 = arith.constant 0 : index
+    %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+    %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+    fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  }
+  func.func @_QQmain() {
+    cuf.sync_descriptor @_QMdevptrEdev_ptr
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32)

``````````

</details>


https://github.com/llvm/llvm-project/pull/121524


More information about the llvm-branch-commits mailing list