[flang-commits] [flang] 44b0ea4 - [flang][OpenACC] Lower wait directive

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Mar 24 09:15:36 PDT 2022


Author: Valentin Clement
Date: 2022-03-24T17:15:27+01:00
New Revision: 44b0ea44f26d405f663d20b067dafd514cb5b26a

URL: https://github.com/llvm/llvm-project/commit/44b0ea44f26d405f663d20b067dafd514cb5b26a
DIFF: https://github.com/llvm/llvm-project/commit/44b0ea44f26d405f663d20b067dafd514cb5b26a.diff

LOG: [flang][OpenACC] Lower wait directive

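This patch adds lowering for the `!$acc wait` directive
from the PFT to the OpenACC dialect, reusing the shared
`genIfClause` and `genAsyncClause` helpers for the `if` and `async` clauses.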

This patch is part of the upstreaming effort from the fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122399

Added: 
    flang/test/Lower/OpenACC/acc-wait.f90

Modified: 
    flang/lib/Lower/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index abac3ff585768..f6db839000690 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -898,16 +898,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   const auto &accClauseList =
       std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
 
-  mlir::Value ifCond, waitDevnum, async;
-  SmallVector<mlir::Value, 2> waitOperands;
+  mlir::Value ifCond, asyncOperand, waitDevnum, async;
+  SmallVector<mlir::Value> waitOperands;
 
   // Async clause have optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
   bool addAsyncAttr = false;
 
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
   Fortran::lower::StatementContext stmtCtx;
 
   if (waitArgument) { // wait has a value.
@@ -930,35 +930,26 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
-  for (const auto &clause : accClauseList.v) {
+  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     if (const auto *ifClause =
             std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
-      mlir::Value cond = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
-      ifCond = firOpBuilder.createConvert(currentLocation,
-                                          firOpBuilder.getI1Type(), cond);
+      genIfClause(converter, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      const auto &asyncClauseValue = asyncClause->v;
-      if (asyncClauseValue) { // async has a value.
-        async = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
-      } else {
-        addAsyncAttr = true;
-      }
+      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
     }
   }
 
   // Prepare the operand segement size attribute and the operands value range.
-  SmallVector<mlir::Value, 8> operands;
-  SmallVector<int32_t, 4> operandSegments;
+  SmallVector<mlir::Value> operands;
+  SmallVector<int32_t> operandSegments;
   addOperands(operands, operandSegments, waitOperands);
   addOperand(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
   addOperand(operands, operandSegments, ifCond);
 
-  auto waitOp = createSimpleOp<mlir::acc::WaitOp>(firOpBuilder, currentLocation,
-                                                  operands, operandSegments);
+  mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
+      firOpBuilder, currentLocation, operands, operandSegments);
 
   if (addAsyncAttr)
     waitOp.asyncAttr(firOpBuilder.getUnitAttr());

diff  --git a/flang/test/Lower/OpenACC/acc-wait.f90 b/flang/test/Lower/OpenACC/acc-wait.f90
new file mode 100644
index 0000000000000..70285999895f7
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-wait.f90
@@ -0,0 +1,41 @@
+! This test checks lowering of OpenACC wait directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+subroutine acc_update  ! NOTE(review): name appears copied from another test; this file (acc-wait.f90) exercises the OpenACC wait directive
+  integer :: async = 1  ! passed as the async(...) argument in the wait(1) async(async) case below
+  logical :: ifCondition = .TRUE.  ! passed as the if(...) condition; checks the logical<4> -> i1 conversion
+
+  !$acc wait
+!CHECK: acc.wait{{$}}
+
+  !$acc wait if(.true.)
+!CHECK: [[IF1:%.*]] = arith.constant true
+!CHECK: acc.wait if([[IF1]]){{$}}
+
+  !$acc wait if(ifCondition)
+!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
+!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
+!CHECK: acc.wait if([[IF2]]){{$}}
+
+  !$acc wait(1, 2)
+!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
+!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}}
+
+  !$acc wait(1) async
+!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
+!CHECK: acc.wait([[WAIT3]] : i32) attributes {async}
+
+  !$acc wait(1) async(async)
+!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
+!CHECK: [[ASYNC1:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
+!CHECK: acc.wait([[WAIT3]] : i32) async([[ASYNC1]] : i32){{$}}
+
+  !$acc wait(devnum: 3: queues: 1, 2)
+!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
+!CHECK: [[DEVNUM:%.*]] = arith.constant 3 : i32
+!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32) wait_devnum([[DEVNUM]] : i32){{$}}
+
+end subroutine acc_update


        


More information about the flang-commits mailing list