[flang-commits] [flang] 864902e - [flang][cuda] Call CUFGetDeviceAddress to get global device address from host address (#112989)
via flang-commits
flang-commits at lists.llvm.org
Fri Oct 18 17:35:41 PDT 2024
Author: Renaud Kauffmann
Date: 2024-10-18T17:35:38-07:00
New Revision: 864902e9b4d8bc6d3f0852d5c475e3dc97dd8335
URL: https://github.com/llvm/llvm-project/commit/864902e9b4d8bc6d3f0852d5c475e3dc97dd8335
DIFF: https://github.com/llvm/llvm-project/commit/864902e9b4d8bc6d3f0852d5c475e3dc97dd8335.diff
LOG: [flang][cuda] Call CUFGetDeviceAddress to get global device address from host address (#112989)
Added:
Modified:
flang/include/flang/Optimizer/Transforms/CufOpConversion.h
flang/lib/Optimizer/Transforms/CufOpConversion.cpp
flang/test/Fir/CUDA/cuda-data-transfer.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
index 79ce4ac5c6cbc0..0a71cdfddec1ab 100644
--- a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
+++ b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h
@@ -18,12 +18,14 @@ class LLVMTypeConverter;
namespace mlir {
class DataLayout;
+class SymbolTable;
}
namespace cuf {
+/// Add the CUF-to-FIR conversion patterns to \p patterns.
+/// \p converter and \p dl are captured by pointer and \p symtab by reference
+/// inside the created patterns, so all three must outlive \p patterns.
+/// \p symtab is used to resolve fir.address_of symbols to fir.global ops.
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
+ const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);
} // namespace cuf
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 91ef1259332de9..9df559ee0ab1f8 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) {
return false;
}
+/// Return \p val unchanged when it already has type \p toTy; otherwise
+/// materialize a fir.convert of \p val to \p toTy at the rewriter's current
+/// insertion point and return its result.
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+                                   mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val) {
+  if (val.getType() == toTy)
+    return val;
+  return rewriter.create<fir::ConvertOp>(loc, toTy, val);
+}
+
+/// Return the address to hand to the data transfer runtime for \p operand.
+///
+/// When the operand value is produced by a fir.declare of a fir.address_of
+/// that names a fir.global carrying a Device, Managed or Pinned CUDA data
+/// attribute, a call to the CUFGetDeviceAddress runtime entry point is
+/// emitted right before the operand's owner and its result is returned.
+/// In every other case the operand's current value is returned unchanged.
+///
+/// \p symtab resolves the symbol referenced by fir.address_of; it must be
+/// the symbol table of the enclosing module.
+///
+/// Kept `static`: this is a file-local helper like createConvertOp, and
+/// external linkage would trigger -Wmissing-prototypes and risk symbol
+/// clashes at link time.
+static mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
+                                    mlir::OpOperand &operand,
+                                    const mlir::SymbolTable &symtab) {
+  mlir::Value v = operand.get();
+  auto declareOp = v.getDefiningOp<fir::DeclareOp>();
+  if (!declareOp)
+    return v;
+
+  auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+  if (!addrOfOp)
+    return v;
+
+  auto globalOp = symtab.lookup<fir::GlobalOp>(
+      addrOfOp.getSymbol().getRootReference().getValue());
+  if (!globalOp)
+    return v;
+
+  // Only globals tagged Device, Managed or Pinned are routed through the
+  // runtime lookup; any other (or missing) data attribute keeps the value.
+  bool isDevGlobal{false};
+  if (auto attr = globalOp.getDataAttrAttr()) {
+    switch (attr.getValue()) {
+    case cuf::DataAttribute::Device:
+    case cuf::DataAttribute::Managed:
+    case cuf::DataAttribute::Pinned:
+      isDevGlobal = true;
+      break;
+    default:
+      break;
+    }
+  }
+  if (!isDevGlobal)
+    return v;
+
+  // Emit the runtime call immediately before the operation that consumes the
+  // operand; the guard restores the caller's insertion point on return.
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(operand.getOwner());
+  auto loc = declareOp.getLoc();
+  auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(rewriter, mod);
+
+  mlir::func::FuncOp callee =
+      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
+  auto fTy = callee.getFunctionType();
+  // Argument 0 is the host address, converted to the runtime's pointer type;
+  // arguments 1 and 2 are the source file name and line number.
+  auto toTy = fTy.getInput(0);
+  mlir::Value inputArg =
+      createConvertOp(rewriter, loc, toTy, declareOp.getResult());
+  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+  mlir::Value sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+      builder, loc, fTy, inputArg, sourceFile, sourceLine)};
+  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+
+  return call->getResult(0);
+}
+
template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
@@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};
-static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
- mlir::Location loc, mlir::Type toTy,
- mlir::Value val) {
- if (val.getType() != toTy)
- return rewriter.create<fir::ConvertOp>(loc, toTy, val);
- return val;
-}
-
struct CufDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
+ CufDataTransferOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symtab)
+ : OpRewritePattern(context), symtab{symtab} {}
+
mlir::LogicalResult
matchAndRewrite(cuf::DataTransferOp op,
mlir::PatternRewriter &rewriter) const override {
@@ -445,9 +504,11 @@ struct CufDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue,
- sourceFile, sourceLine)};
+ mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+ mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+ llvm::SmallVector<mlir::Value> args{
+ fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+ modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
@@ -552,6 +613,9 @@ struct CufDataTransferOpConversion
}
return mlir::success();
}
+
+private:
+ const mlir::SymbolTable &symtab;
};
class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
@@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
if (!module)
return signalPassFailure();
+ mlir::SymbolTable symtab(module);
std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
/*forceUnifiedTBAATree=*/false, *dl);
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
- cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns);
+ cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
+ patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
@@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
+/// Register the CUF-to-FIR conversion patterns. \p dl and \p converter are
+/// captured by pointer and \p symtab by reference, so all must outlive
+/// \p patterns.
void cuf::populateCUFToFIRConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
- mlir::RewritePatternSet &patterns) {
+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
- CufFreeOpConversion, CufDataTransferOpConversion>(
- patterns.getContext());
+ CufFreeOpConversion>(patterns.getContext());
+  // The data transfer pattern needs the module symbol table to resolve
+  // fir.address_of operands to their fir.global and detect CUDA globals.
+ patterns.insert<CufDataTransferOpConversion>(patterns.getContext(), symtab);
}
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index ed894aed5534a0..c33c50115b9fc0 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -189,4 +189,47 @@ func.func @_QPsub7() {
// CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// Module-level global with the device data attribute, shared by sub8 and sub9.
+// A transfer operand declared from it must be rewritten to the address
+// returned by the CUFGetDeviceAddress runtime call.
+fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
+// sub8: the device global is the SOURCE of a device-to-host transfer; the
+// local host array destination is passed through unchanged.
+func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
+ %c5 = arith.constant 5 : index
+ %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFEm"}
+ %1 = fir.shape %c5 : (index) -> !fir.shape<1>
+ %2 = fir.declare %0(%1) {uniq_name = "_QFEm"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+ %3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+ %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+ cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+ return
+}
+
+// Expected: the declared global is converted to a pointer, fed to the device
+// address lookup, and the lookup result is used as the transfer source.
+// CHECK-LABEL: func.func @_QPsub8()
+// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
+// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
+// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+
+// sub9: the device global is the DESTINATION of a host-to-device transfer;
+// the local host array source is passed through unchanged.
+func.func @_QPsub9() {
+ %c5 = arith.constant 5 : index
+ %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFtest9Em"}
+ %1 = fir.shape %c5 : (index) -> !fir.shape<1>
+ %2 = fir.declare %0(%1) {uniq_name = "_QFtest9Em"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+ %3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+ %4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+ cuf.data_transfer %2 to %4 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+ return
+}
+
+// Expected: the device address lookup result is used as the transfer
+// destination, while the local array is only converted to a pointer.
+// CHECK-LABEL: func.func @_QPsub9()
+// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
+// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
+// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
} // end of module
More information about the flang-commits mailing list