[llvm-branch-commits] [flang] [flang][cuda] Convert cuf.sync_descriptor to runtime call (PR #121524)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 2 14:05:39 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-runtime
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Convert the `cuf.sync_descriptor` op to a call to a new runtime entry point, `CUFSyncGlobalDescriptor`.
---
Full diff: https://github.com/llvm/llvm-project/pull/121524.diff
4 Files Affected:
- (modified) flang/include/flang/Runtime/CUDA/descriptor.h (+4)
- (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+41-1)
- (modified) flang/runtime/CUDA/descriptor.cpp (+7)
- (added) flang/test/Fir/CUDA/cuda-sync-desc.mlir (+20)
``````````diff
diff --git a/flang/include/flang/Runtime/CUDA/descriptor.h b/flang/include/flang/Runtime/CUDA/descriptor.h
index 55878aaac57fb3..0ee7feca10e44c 100644
--- a/flang/include/flang/Runtime/CUDA/descriptor.h
+++ b/flang/include/flang/Runtime/CUDA/descriptor.h
@@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)(
void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
const char *sourceFile = nullptr, int sourceLine = 0);
+/// Get the device address registered with \p hostPtr and sync the descriptors.
+void RTDECL(CUFSyncGlobalDescriptor)(
+ void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
+
} // extern "C"
} // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index fb0ef246546444..f08f9e412b8857 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -788,6 +788,45 @@ struct CUFLaunchOpConversion
const mlir::SymbolTable &symTab;
};
+struct CUFSyncDescriptorOpConversion
+ : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFSyncDescriptorOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symTab)
+ : OpRewritePattern(context), symTab{symTab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::SyncDescriptorOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
+ if (!globalOp)
+ return mlir::failure();
+
+ auto hostAddr = builder.create<fir::AddrOfOp>(
+ loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, callee, args);
+ op.erase();
+ return mlir::success();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
+
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
@@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns(
CUFFreeOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
- patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>(
+ patterns.getContext(), symtab);
}
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
diff --git a/flang/runtime/CUDA/descriptor.cpp b/flang/runtime/CUDA/descriptor.cpp
index 391c47e84241d4..947eeb66aa3d6c 100644
--- a/flang/runtime/CUDA/descriptor.cpp
+++ b/flang/runtime/CUDA/descriptor.cpp
@@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
(void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
}
+void RTDEF(CUFSyncGlobalDescriptor)(
+ void *hostPtr, const char *sourceFile, int sourceLine) {
+ void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
+ RTNAME(CUFDescriptorSync)
+ ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
+}
+
RT_EXT_API_GROUP_END
}
} // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-sync-desc.mlir b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
new file mode 100644
index 00000000000000..20b317f34a7f26
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
@@ -0,0 +1,20 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git at github.com:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+ fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+ %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+ %c0 = arith.constant 0 : index
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ }
+ func.func @_QQmain() {
+ cuf.sync_descriptor @_QMdevptrEdev_ptr
+ return
+ }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32)
``````````
</details>
https://github.com/llvm/llvm-project/pull/121524
More information about the llvm-branch-commits
mailing list