[Mlir-commits] [flang] [mlir] [flang][openacc] Add ability to link acc.declare_enter with acc.declare_exit ops (PR #72476)

Valentin Clement バレンタイン クレメン llvmlistbot at llvm.org
Wed Nov 15 21:42:26 PST 2023


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/72476

>From 1c46556685a440cbb482ddc02391ab8f15b95eb0 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 15 Nov 2023 15:28:09 -0800
Subject: [PATCH 1/3] [flang][openacc] Add ability to link acc.declare_enter
 with acc.declare_exit

---
 flang/lib/Lower/OpenACC.cpp                   | 45 ++++++++++++-------
 .../test/Lower/OpenACC/HLFIR/acc-declare.f90  | 28 ++++++------
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  8 +++-
 .../mlir/Dialect/OpenACC/OpenACCOpsTypes.td   |  9 ++++
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       |  8 +++-
 5 files changed, 63 insertions(+), 35 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e470154ce8c2d0b..8c6c22210cf0894 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -163,7 +163,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
   builder.restoreInsertionPoint(crtInsPt);
@@ -195,7 +196,7 @@ static void createDeclareDeallocFuncWithArg(
           /*structured=*/false, /*implicit=*/false, clause,
           boxAddrOp.getType());
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -2762,7 +2763,13 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
   EntryOp entryOp = createDataEntryOp<EntryOp>(
       builder, loc, addrOp.getResTy(), asFortran, bounds,
       /*structured=*/false, implicit, clause, addrOp.getResTy().getType());
-  builder.create<DeclareOp>(loc, mlir::ValueRange(entryOp.getAccPtr()));
+  if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
+    builder.create<DeclareOp>(
+        loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+        mlir::ValueRange(entryOp.getAccPtr()));
+  else
+    builder.create<DeclareOp>(loc, mlir::Value{},
+                              mlir::ValueRange(entryOp.getAccPtr()));
   mlir::Value varPtr;
   if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
     builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
@@ -2812,7 +2819,8 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
 }
@@ -2850,7 +2858,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           boxAddrOp.getType());
 
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -3092,34 +3100,37 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
 
   mlir::func::FuncOp funcOp = builder.getFunction();
   auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
+  mlir::Value declareToken;
   if (ops.empty()) {
-    builder.create<mlir::acc::DeclareEnterOp>(loc, dataClauseOperands);
+    declareToken = builder.create<mlir::acc::DeclareEnterOp>(
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        dataClauseOperands);
   } else {
     auto declareOp = *ops.begin();
     auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
-        loc, declareOp.getDataClauseOperands());
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        declareOp.getDataClauseOperands());
     newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
+    declareToken = newDeclareOp.getToken();
     declareOp.erase();
   }
 
   openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
                             copyEntryOperands, copyoutEntryOperands,
-                            deviceResidentEntryOperands]() {
+                            deviceResidentEntryOperands, declareToken]() {
     llvm::SmallVector<mlir::Value> operands;
     operands.append(createEntryOperands);
     operands.append(deviceResidentEntryOperands);
     operands.append(copyEntryOperands);
     operands.append(copyoutEntryOperands);
 
-    if (!operands.empty()) {
-      mlir::func::FuncOp funcOp = builder.getFunction();
-      auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
-      if (ops.empty()) {
-        builder.create<mlir::acc::DeclareExitOp>(loc, operands);
-      } else {
-        auto declareOp = *ops.begin();
-        declareOp.getDataClauseOperandsMutable().append(operands);
-      }
+    mlir::func::FuncOp funcOp = builder.getFunction();
+    auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
+    if (ops.empty()) {
+      builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
+    } else {
+      auto declareOp = *ops.begin();
+      declareOp.getDataClauseOperandsMutable().append(operands);
     }
 
     genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
diff --git a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
index 489f81d297df4d2..b0a78fbd5439fb9 100644
--- a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
@@ -24,11 +24,11 @@ subroutine acc_declare_copy()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) to varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 
@@ -51,11 +51,11 @@ subroutine acc_declare_create()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL: return
 
@@ -119,10 +119,10 @@ subroutine acc_declare_copyout()
 ! HLFIR: %[[ADECL:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_copyout>, uniq_name = "_QMacc_declareFacc_declare_copyoutEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
 
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) to varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) to varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! ALL: return
@@ -178,9 +178,9 @@ subroutine acc_declare_device_resident(a)
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_residentEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 
   subroutine acc_declare_device_resident2()
@@ -195,8 +195,8 @@ subroutine acc_declare_device_resident2()
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOCA]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_resident2Edataparam"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xf32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "dataparam"}
 
   subroutine acc_declare_link2()
@@ -234,10 +234,10 @@ function acc_declare_in_func()
 
 ! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 {
 ! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 
 ! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref<f32>
-! HLFIR: acc.declare_exit dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) bounds(%6) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 ! HLFIR: return %[[LOAD]] : f32
 ! ALL: }
@@ -254,10 +254,10 @@ function acc_declare_in_func2(i)
 ! HLFIR: %[[ALLOCA_A:.*]] = fir.alloca !fir.array<1024xf32> {bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"}
 ! HLFIR: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOCA_A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_create>, uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} : (!fir.ref<!fir.array<1024xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xf32>>, !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%7) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR:   cf.br ^bb1
 ! HLFIR: ^bb1:
-! HLFIR: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>) bounds(%7) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL:   return %{{.*}} : f32
 ! ALL: }
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index d0b52a0b4024172..391e77e0c4081a3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1430,6 +1430,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   }];
 
   let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let results = (outs OpenACC_DeclareTokenType:$token);
 
   let assemblyFormat = [{
     oilist(
@@ -1441,7 +1442,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   let hasVerifier = 1;
 }
 
-def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
+def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", [AttrSizedOperandSegments]> {
   let summary = "declare directive - exit from implicit data region";
 
   let description = [{
@@ -1458,10 +1459,13 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
     ```
   }];
 
-  let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let arguments = (ins
+      Optional<OpenACC_DeclareTokenType>:$token,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
 
   let assemblyFormat = [{
     oilist(
+        `token` `(` $token `)` |
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
index 4a930ad94c3f175..92ea71a7e8418de 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
@@ -24,4 +24,13 @@ def OpenACC_DataBoundsType : OpenACC_Type<"DataBounds", "data_bounds_ty"> {
   let summary = "Type for representing acc data clause bounds information";
 }
 
+def OpenACC_DeclareTokenType : OpenACC_Type<"DeclareToken", "declare_token"> {
+  let summary = "declare token type";
+  let description = [{
+    `acc.declare_token` is the type of the token returned by a `declare_enter`
+    operation. Passing that token to a `declare_exit` operation links the two
+    operations as the entry and exit of the same implicit data region.
+  }];
+}
+
 #endif // OPENACC_OPS_TYPES
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6e5df705fee05d8..b30ffe638999f9f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1088,8 +1088,9 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
 
 template <typename Op>
 static LogicalResult checkDeclareOperands(Op &op,
-                                          const mlir::ValueRange &operands) {
-  if (operands.empty())
+                                          const mlir::ValueRange &operands,
+                                          bool requireAtLeastOnOperand = true) {
+  if (operands.empty() && requireAtLeastOnOperand)
     return emitError(
         op->getLoc(),
         "at least one operand must appear on the declare operation");
@@ -1151,6 +1152,9 @@ LogicalResult acc::DeclareEnterOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::DeclareExitOp::verify() {
+  if (getToken())
+    return checkDeclareOperands(*this, this->getDataClauseOperands(),
+                                /*requireAtLeastOneOperand=*/false);
   return checkDeclareOperands(*this, this->getDataClauseOperands());
 }
 

>From f6ba6f80b06fdb2fcd94b683bb3fa09c836b34fd Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 15 Nov 2023 21:38:07 -0800
Subject: [PATCH 2/3] Fix variable name

---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index b30ffe638999f9f..dabdf05bb44dd4d 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1089,8 +1089,8 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
 template <typename Op>
 static LogicalResult checkDeclareOperands(Op &op,
                                           const mlir::ValueRange &operands,
-                                          bool requireAtLeastOnOperand = true) {
-  if (operands.empty() && requireAtLeastOnOperand)
+                                          bool requireAtLeastOneOperand = true) {
+  if (operands.empty() && requireAtLeastOneOperand)
     return emitError(
         op->getLoc(),
         "at least one operand must appear on the declare operation");

>From ec2aaf2c93802c5ac2fefcfd0866b0460b28f7f3 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 15 Nov 2023 21:42:13 -0800
Subject: [PATCH 3/3] clang-format

---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index dabdf05bb44dd4d..08e83cad482207d 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1087,9 +1087,9 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
 //===----------------------------------------------------------------------===//
 
 template <typename Op>
-static LogicalResult checkDeclareOperands(Op &op,
-                                          const mlir::ValueRange &operands,
-                                          bool requireAtLeastOneOperand = true) {
+static LogicalResult
+checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
+                     bool requireAtLeastOneOperand = true) {
   if (operands.empty() && requireAtLeastOneOperand)
     return emitError(
         op->getLoc(),



More information about the Mlir-commits mailing list