[flang-commits] [flang] ef8d88c - [flang][cuda] Support scalar to array data transfer (#115273)

via flang-commits flang-commits at lists.llvm.org
Thu Nov 7 09:27:14 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-07T09:27:10-08:00
New Revision: ef8d88ca1af0a8348bc616e93d50919462224d9b

URL: https://github.com/llvm/llvm-project/commit/ef8d88ca1af0a8348bc616e93d50919462224d9b
DIFF: https://github.com/llvm/llvm-project/commit/ef8d88ca1af0a8348bc616e93d50919462224d9b.diff

LOG: [flang][cuda] Support scalar to array data transfer (#115273)

Do it via descriptor assignment until we have a more efficient way.

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
    flang/test/Fir/CUDA/cuda-data-transfer.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 6187ca03d2c411..881f54133ce732 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
   return mlir::Value{};
 }
 
+/// Build a fir.box<`valueTy`> around `addr` using `shape` (which may be null)
+/// and store it into a new temporary; return that temporary.
+static mlir::Value createBoxInMemory(fir::FirOpBuilder &builder,
+                                     mlir::Location loc, mlir::Value addr,
+                                     mlir::Type valueTy, mlir::Value shape) {
+  llvm::SmallVector<mlir::Value> lenParams;
+  mlir::Value box = builder.createBox(loc, fir::BoxType::get(valueTy), addr,
+                                      shape, /*slice=*/nullptr, lenParams,
+                                      /*tdesc=*/nullptr);
+  mlir::Value mem = builder.createTemporary(loc, box.getType());
+  builder.create<fir::StoreOp>(loc, box, mem);
+  return mem;
+}
+
+/// Return a temporary holding a descriptor for `op`'s source; a trivial
+/// constant source is first stored to memory so the box has an address.
+static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
+                            cuf::DataTransferOp op,
+                            const mlir::SymbolTable &symtab) {
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  mlir::Location loc = op.getLoc();
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
+  mlir::Value addr;
+  if (fir::isa_trivial(srcTy) &&
+      mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
+    addr = builder.createTemporary(loc, srcTy);
+    builder.create<fir::StoreOp>(loc, op.getSrc(), addr);
+  } else {
+    addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+  }
+  return createBoxInMemory(builder, loc, addr, srcTy,
+                           getShapeFromDecl(op.getSrc()));
+}
+
+/// Return a temporary holding a descriptor for `op`'s destination.
+static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
+                            cuf::DataTransferOp op,
+                            const mlir::SymbolTable &symtab) {
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
+  mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+  return createBoxInMemory(builder, op.getLoc(), addr, dstTy,
+                           getShapeFromDecl(op.getDst()));
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
         !mlir::isa<fir::BaseBoxType>(dstTy)) {
 
       if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
-        // TODO: scalar to array data transfer.
-        mlir::emitError(loc,
-                        "not yet implemented: scalar to array data transfer\n");
-        return mlir::failure();
+        // Initialization of an array from a scalar value should be implemented
+        // via a kernel launch. Use the flang runtime via the Assign function
+        // until we have more infrastructure.
+        mlir::Value src = emboxSrc(rewriter, op, symtab);
+        mlir::Value dst = emboxDst(rewriter, op, symtab);
+        mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+            CUFDataTransferDescDescNoRealloc)>(loc, builder);
+        auto fTy = func.getFunctionType();
+        mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+        mlir::Value sourceLine =
+            fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+        llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+            builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
+        builder.create<fir::CallOp>(loc, func, args);
+        rewriter.eraseOp(op);
+        return mlir::success();
       }
 
       mlir::Type i64Ty = builder.getI64Type();
@@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
       mlir::Value dst = op.getDst();
       mlir::Value src = op.getSrc();
 
-      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
-        // If src is not a descriptor, create one.
-        mlir::Value addr;
-        if (fir::isa_trivial(srcTy) &&
-            mlir::matchPattern(op.getSrc().getDefiningOp(),
-                               mlir::m_Constant())) {
-          // Put constant in memory if it is not.
-          mlir::Value alloc = builder.createTemporary(loc, srcTy);
-          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
-          addr = alloc;
-        } else {
-          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
-        }
-        mlir::Type boxTy = fir::BoxType::get(srcTy);
-        llvm::SmallVector<mlir::Value> lenParams;
-        mlir::Value box =
-            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
-                              /*slice=*/nullptr, lenParams,
-                              /*tdesc=*/nullptr);
-        mlir::Value memBox = builder.createTemporary(loc, box.getType());
-        builder.create<fir::StoreOp>(loc, box, memBox);
-        src = memBox;
-      }
+      if (!mlir::isa<fir::BaseBoxType>(srcTy))
+        src = emboxSrc(rewriter, op, symtab);
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
       rewriter.eraseOp(op);
     } else {
       // Transfer from a descriptor.
-
-      mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
-      mlir::Type boxTy = fir::BoxType::get(dstTy);
-      llvm::SmallVector<mlir::Value> lenParams;
-      mlir::Value box =
-          builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
-                            /*slice=*/nullptr, lenParams,
-                            /*tdesc=*/nullptr);
-      mlir::Value memBox = builder.createTemporary(loc, box.getType());
-      builder.create<fir::StoreOp>(loc, box, memBox);
+      mlir::Value dst = emboxDst(rewriter, op, symtab);
 
       mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
           CUFDataTransferDescDescNoRealloc)>(loc, builder);
@@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       llvm::SmallVector<mlir::Value> args{
-          fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
+          fir::runtime::createArguments(builder, loc, fTy, dst, op.getSrc(),
                                         modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);

diff  --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index d9588942b21e81..8497aee2e2cf9c 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -281,4 +281,18 @@ func.func @_QPdesc_global_ptr() {
 // CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+func.func @_QPscalar_to_array() {
+  %c1_i32 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : index
+  %0 = cuf.alloc !fir.array<10xi32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} -> !fir.ref<!fir.array<10xi32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+  cuf.data_transfer %c1_i32 to %2#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
+  cuf.free %2#1 : !fir.ref<!fir.array<10xi32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+// A constant scalar transferred to a device array lowers to the desc-desc call.
+// CHECK-LABEL: func.func @_QPscalar_to_array()
+// CHECK: _FortranACUFDataTransferDescDescNoRealloc
+
 } // end of module


        


More information about the flang-commits mailing list