[flang-commits] [flang] [flang][cuda] Add fir.cuda_allocate operation (PR #88586)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Apr 12 15:40:13 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/88586

Allocatables with the CUDA `device` attribute have special semantics for the ALLOCATE statement. In flang, the ALLOCATE statement is lowered to a sequence of runtime calls that initialize the descriptor and then allocate the descriptor data. This new operation will replace the last runtime call and abstract away all the device memory allocation needed.
The lowering patch will follow. 

>From c0b5d4d1344cd450ca3ceae3206610f6c4978484 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 12 Apr 2024 13:17:42 -0700
Subject: [PATCH] [flang][cuda] Add fir.cuda_allocate operation

---
 .../include/flang/Optimizer/Dialect/FIROps.td | 31 ++++++++
 .../flang/Optimizer/Dialect/FIRTypes.td       |  1 +
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 19 +++++
 flang/test/Fir/cuf-invalid.fir                | 50 +++++++++++++
 flang/test/Fir/cuf.mlir                       | 70 +++++++++++++++++++
 5 files changed, 171 insertions(+)
 create mode 100644 flang/test/Fir/cuf-invalid.fir
 create mode 100644 flang/test/Fir/cuf.mlir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dff1cdb20cbfef..5f1a097edf09cc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3190,4 +3190,54 @@ def fir_CUDADataTransferOp : fir_Op<"cuda_data_transfer", []> {
   }];
 }
 
+def fir_CUDAAllocateOp : fir_Op<"cuda_allocate", [AttrSizedOperandSegments,
+    MemoryEffects<[MemAlloc<DefaultResource>]>]> {
+  let summary = "Perform the device allocation of data of an allocatable";
+
+  let description = [{
+    The fir.cuda_allocate operation performs the allocation on the device
+    of the data of an allocatable. The descriptor passed to the operation
+    is initialized before with the standard flang runtime calls.
+
+    Operands and attributes:
+    - `box`: the descriptor of the allocatable, as a box or class value or
+      a reference to one.
+    - `source` (optional): a box or class value, or a reference to one.
+    - `errmsg` (optional): a box value, or a reference to one. When it is
+      present, the `hasStat` unit attribute must be set as well.
+    - `stream` (optional): an integer value. Cannot be combined with
+      `pinned`.
+    - `pinned` (optional): a reference or box value. Cannot be combined
+      with `stream`.
+
+    The operation returns an integer `stat` value.
+
+    Example:
+
+    ```mlir
+      %stat = fir.cuda_allocate %box : !fir.ref<!fir.box<none>> -> i32
+    ```
+  }];
+
+  let arguments = (ins AnyRefOrBoxType:$box,
+                       Optional<AnyRefOrBoxType>:$errmsg,
+                       Optional<AnyIntegerType>:$stream,
+                       Optional<AnyRefOrBoxType>:$pinned,
+                       Optional<AnyRefOrBoxType>:$source,
+                       UnitAttr:$hasStat);
+
+  let results = (outs AnyIntegerType:$stat);
+
+  let assemblyFormat = [{
+    $box `:` qualified(type($box))
+    ( `source` `(` $source^ `:` qualified(type($source)) `)` )?
+    ( `errmsg` `(` $errmsg^ `:` type($errmsg) `)` )?
+    ( `stream` `(` $stream^ `:` type($stream) `)` )?
+    ( `pinned` `(` $pinned^ `:` type($pinned) `)` )?
+    attr-dict `->` type($stat)
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 4c6a8064991ab0..3b876e4642da9a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -625,6 +625,10 @@ def AnyRefOrBoxLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
 def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
     fir_HeapType.predicate, fir_PointerType.predicate,
     IsBaseBoxTypePred]>, "any reference or box">;
+// Same predicate as AnyRefOrBox, but declared as a Type rather than a
+// TypeConstraint so that it can be wrapped in ODS Optional<>.
+def AnyRefOrBoxType : Type<AnyRefOrBox.predicate,
+                           "any legal ref or box type">;
 
 def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
     fir_ShapeShiftType.predicate]>, "any legal shape type">;
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8ab74103cb6a80..88710880174d21 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3993,6 +3993,41 @@ mlir::LogicalResult fir::CUDAKernelOp::verify() {
   return mlir::success();
 }
 
+/// Verify the fir.cuda_allocate constraints that the ODS operand types do
+/// not capture:
+///   - `pinned` and `stream` cannot both be present;
+///   - `box` and, when present, `source` must be a box or class value or a
+///     reference to one;
+///   - `errmsg`, when present, must be a box value or a reference to one;
+///   - `errmsg` requires the `hasStat` attribute.
+mlir::LogicalResult fir::CUDAAllocateOp::verify() {
+  if (getPinned() && getStream())
+    return emitOpError("pinned and stream cannot appears at the same time");
+
+  // fir::unwrapRefType is applied first so that a reference to a box is
+  // accepted as well as a box value.
+  auto isBoxOrRefToBox = [](mlir::Type ty) {
+    return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(ty));
+  };
+  if (!isBoxOrRefToBox(getBox().getType()))
+    return emitOpError(
+        "expect box to be a reference to/or a class or box type value");
+  if (getSource() && !isBoxOrRefToBox(getSource().getType()))
+    return emitOpError(
+        "expect source to be a reference to/or a class or box type value");
+
+  if (auto errmsg = getErrmsg()) {
+    // Checked against fir::BoxType rather than fir::BaseBoxType: class
+    // types are not accepted for errmsg.
+    if (!mlir::isa<fir::BoxType>(fir::unwrapRefType(errmsg.getType())))
+      return emitOpError(
+          "expect errmsg to be a reference to/or a box type value");
+    if (!getHasStat())
+      return emitOpError("expect stat attribute when errmsg is provided");
+  }
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
new file mode 100644
index 00000000000000..a5258acffa123e
--- /dev/null
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -0,0 +1,62 @@
+// RUN: fir-opt -split-input-file -verify-diagnostics %s
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  // expected-error@+1{{'fir.cuda_allocate' op pinned and stream cannot appears at the same time}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %1 = fir.alloca i32
+  // expected-error@+1{{'fir.cuda_allocate' op expect box to be a reference to/or a class or box type value}}
+  %2 = fir.cuda_allocate %1 : !fir.ref<i32> -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  // expected-error@+1{{'fir.cuda_allocate' op expect stat attribute when errmsg is provided}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %1 = fir.alloca i32
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  // expected-error@+1{{'fir.cuda_allocate' op expect errmsg to be a reference to/or a box type value}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%1 : !fir.ref<i32>) {hasStat} -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  // expected-error@+1{{'fir.cuda_allocate' op expect source to be a reference to/or a class or box type value}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> source(%1 : !fir.ref<i32>) -> i32
+  return
+}
diff --git a/flang/test/Fir/cuf.mlir b/flang/test/Fir/cuf.mlir
new file mode 100644
index 00000000000000..bf4f866d464bbd
--- /dev/null
+++ b/flang/test/Fir/cuf.mlir
@@ -0,0 +1,72 @@
+// RUN: fir-opt --split-input-file %s | fir-opt --split-input-file | FileCheck %s
+
+// Simple round trip test of operations.
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> stream(%{{.*}} : i32) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub1Eb"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %5:2 = hlfir.declare %1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %12 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> source(%12 : !fir.ref<!fir.box<none>>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> source(%{{.*}} : !fir.ref<!fir.box<none>>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> pinned(%{{.*}} : !fir.ref<i1>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {hasStat} -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> errmsg(%{{.*}} : !fir.box<none>) {hasStat} -> i32



More information about the flang-commits mailing list