[flang-commits] [flang] [flang][cuda] Add fir.cuda_allocate operation (PR #88586)

via flang-commits flang-commits at lists.llvm.org
Fri Apr 12 15:40:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Allocatables with the CUDA device attribute have special semantics for the allocate statement. In flang, the allocate statement is lowered to a sequence of runtime calls that initialize the descriptor and then allocate the descriptor data. This new operation will replace the last runtime call and abstract all the device memory allocation needed.
The lowering patch will follow. 

---
Full diff: https://github.com/llvm/llvm-project/pull/88586.diff


5 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+31) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+1) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+19) 
- (added) flang/test/Fir/cuf-invalid.fir (+50) 
- (added) flang/test/Fir/cuf.mlir (+70) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dff1cdb20cbfef..5f1a097edf09cc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3190,4 +3190,35 @@ def fir_CUDADataTransferOp : fir_Op<"cuda_data_transfer", []> {
   }];
 }
 
+def fir_CUDAAllocateOp : fir_Op<"cuda_allocate", [AttrSizedOperandSegments,
+    MemoryEffects<[MemAlloc<DefaultResource>]>]> {
+  let summary = "Perform the device allocation of data of an allocatable";
+
+  let description = [{
+    The fir.cuda_allocate operation performs the allocation on the device
+    of the data of an allocatable. The descriptor passed to the operation
+    is initialized before with the standard flang runtime calls.
+  }];
+
+  let arguments = (ins AnyRefOrBoxType:$box,
+                       Optional<AnyRefOrBoxType>:$errmsg,
+                       Optional<AnyIntegerType>:$stream,
+                       Optional<AnyRefOrBoxType>:$pinned,
+                       Optional<AnyRefOrBoxType>:$source,
+                       UnitAttr:$hasStat);
+
+  let results = (outs AnyIntegerType:$stat);
+
+  let assemblyFormat = [{
+    $box `:` qualified(type($box))
+    ( `source` `(` $source^ `:` qualified(type($source) )`)` )?
+    ( `errmsg` `(` $errmsg^ `:` type($errmsg) `)` )?
+    ( `stream` `(` $stream^ `:` type($stream) `)` )?
+    ( `pinned` `(` $pinned^ `:` type($pinned) `)` )?
+    attr-dict `->` type($stat)
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 4c6a8064991ab0..3b876e4642da9a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -625,6 +625,7 @@ def AnyRefOrBoxLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
 def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
     fir_HeapType.predicate, fir_PointerType.predicate,
     IsBaseBoxTypePred]>, "any reference or box">;
+def AnyRefOrBoxType : Type<AnyRefOrBox.predicate, "any legal ref or box type">;
 
 def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
     fir_ShapeShiftType.predicate]>, "any legal shape type">;
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8ab74103cb6a80..88710880174d21 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3993,6 +3993,25 @@ mlir::LogicalResult fir::CUDAKernelOp::verify() {
   return mlir::success();
 }
 
+mlir::LogicalResult fir::CUDAAllocateOp::verify() {
+  if (getPinned() && getStream())
+    return emitOpError("pinned and stream cannot appears at the same time");
+  if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>())
+    return emitOpError(
+        "expect box to be a reference to/or a class or box type value");
+  if (getSource() &&
+      !fir::unwrapRefType(getSource().getType()).isa<fir::BaseBoxType>())
+    return emitOpError(
+        "expect source to be a reference to/or a class or box type value");
+  if (getErrmsg() &&
+      !fir::unwrapRefType(getErrmsg().getType()).isa<fir::BoxType>())
+    return emitOpError(
+        "expect errmsg to be a reference to/or a box type value");
+  if (getErrmsg() && !getHasStat())
+    return emitOpError("expect stat attribute when errmsg is provided");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
new file mode 100644
index 00000000000000..a5258acffa123e
--- /dev/null
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -0,0 +1,50 @@
+// RUN: fir-opt -split-input-file -verify-diagnostics %s
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  // expected-error@+1{{'fir.cuda_allocate' op pinned and stream cannot appears at the same time}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %1 = fir.alloca i32
+  // expected-error@+1{{'fir.cuda_allocate' op expect box to be a reference to/or a class or box type value}}
+  %2 = fir.cuda_allocate %1 : !fir.ref<i32> -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  // expected-error@+1{{'fir.cuda_allocate' op expect stat attribute when errmsg is provided}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) -> i32
+  return
+}
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %1 = fir.alloca i32
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  // expected-error@+1{{'fir.cuda_allocate' op expect errmsg to be a reference to/or a box type value}}
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%1 : !fir.ref<i32>) {hasStat} -> i32
+  return
+}
diff --git a/flang/test/Fir/cuf.mlir b/flang/test/Fir/cuf.mlir
new file mode 100644
index 00000000000000..bf4f866d464bbd
--- /dev/null
+++ b/flang/test/Fir/cuf.mlir
@@ -0,0 +1,70 @@
+// RUN: fir-opt --split-input-file %s | fir-opt --split-input-file | FileCheck %s
+
+// Simple round trip test of operations.
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca i32
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %s = fir.load %1 : !fir.ref<i32>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> stream(%{{.*}} : i32) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %1 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub1Eb"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %5:2 = hlfir.declare %1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %12 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> source(%12 : !fir.ref<!fir.box<none>>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> source(%{{.*}} : !fir.ref<!fir.box<none>>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %pinned = fir.alloca i1
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> pinned(%pinned : !fir.ref<i1>) -> i32
+  return
+}
+
+// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> pinned(%{{.*}} : !fir.ref<i1>) -> i32
+
+// -----
+
+func.func @_QPsub1() {
+  %0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
+  %4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c100 = arith.constant 100 : index
+  %7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
+  %8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
+  %9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
+  %11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
+  %13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {hasStat} -> i32
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88586


More information about the flang-commits mailing list