[flang-commits] [flang] c6a9ce2 - [flang][OpenACC] Lower update directive

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Mar 24 07:19:13 PDT 2022


Author: Valentin Clement
Date: 2022-03-24T15:19:05+01:00
New Revision: c6a9ce2b6b746161ea5ed8aaf9e0c4b23ffb87b3

URL: https://github.com/llvm/llvm-project/commit/c6a9ce2b6b746161ea5ed8aaf9e0c4b23ffb87b3
DIFF: https://github.com/llvm/llvm-project/commit/c6a9ce2b6b746161ea5ed8aaf9e0c4b23ffb87b3.diff

LOG: [flang][OpenACC] Lower update directive

This patch adds lowering for the `!$acc update` directive
from the PFT to the OpenACC dialect.

This patch is part of the upstreaming effort from the fir-dev branch.

Depends on D122387

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122396

Added: 
    flang/test/Lower/OpenACC/acc-update.f90

Modified: 
    flang/lib/Lower/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ed20ee2288c43..922c53ef43123 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -802,7 +802,7 @@ static void
 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                const Fortran::parser::AccClauseList &accClauseList) {
   mlir::Value ifCond, async, waitDevnum;
-  SmallVector<Value, 2> hostOperands, deviceOperands, waitOperands,
+  SmallVector<mlir::Value> hostOperands, deviceOperands, waitOperands,
       deviceTypeOperands;
 
   // Async and wait clause have optional values but can be present with
@@ -812,68 +812,29 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   bool addWaitAttr = false;
   bool addIfPresentAttr = false;
 
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
   Fortran::lower::StatementContext stmtCtx;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
-  for (const auto &clause : accClauseList.v) {
+  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     if (const auto *ifClause =
             std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
-      mlir::Value cond = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
-      ifCond = firOpBuilder.createConvert(currentLocation,
-                                          firOpBuilder.getI1Type(), cond);
+      genIfClause(converter, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      const auto &asyncClauseValue = asyncClause->v;
-      if (asyncClauseValue) { // async has a value.
-        async = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
-      } else {
-        addAsyncAttr = true;
-      }
+      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      const auto &waitClauseValue = waitClause->v;
-      if (waitClauseValue) { // wait has a value.
-        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
-        const std::list<Fortran::parser::ScalarIntExpr> &waitList =
-            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
-          mlir::Value v = fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(value), stmtCtx));
-          waitOperands.push_back(v);
-        }
-
-        const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
-            std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-        if (waitDevnumValue)
-          waitDevnum = fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
-      } else {
-        addWaitAttr = true;
-      }
+      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
+                    addWaitAttr, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-
-      const auto &deviceTypeValue = deviceTypeClause->v;
-      if (deviceTypeValue) {
-        for (const auto &scalarIntExpr : *deviceTypeValue) {
-          mlir::Value expr = fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx));
-          deviceTypeOperands.push_back(expr);
-        }
-      } else {
-        // * was passed as value and will be represented as a -1 constant
-        // integer.
-        mlir::Value star = firOpBuilder.createIntegerConstant(
-            currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1);
-        deviceTypeOperands.push_back(star);
-      }
+      genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands,
+                          stmtCtx);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genObjectList(hostClause->v, converter, hostOperands);
@@ -884,17 +845,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   }
 
   // Prepare the operand segement size attribute and the operands value range.
-  SmallVector<mlir::Value, 14> operands;
-  SmallVector<int32_t, 7> operandSegments;
+  SmallVector<mlir::Value> operands;
+  SmallVector<int32_t> operandSegments;
+  addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, deviceTypeOperands);
-  addOperand(operands, operandSegments, ifCond);
   addOperands(operands, operandSegments, hostOperands);
   addOperands(operands, operandSegments, deviceOperands);
 
-  auto updateOp = createSimpleOp<mlir::acc::UpdateOp>(
+  mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
       firOpBuilder, currentLocation, operands, operandSegments);
 
   if (addAsyncAttr)

diff  --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
new file mode 100644
index 0000000000000..1b3aef2dc674b
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -0,0 +1,72 @@
+! This test checks lowering of the OpenACC update directive to acc.update.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+subroutine acc_update
+  integer :: async = 1  ! non-constant async() argument, lowered through fir.load
+  real, dimension(10, 10) :: a, b, c  ! data operands for the host()/device() clauses
+  logical :: ifCondition = .TRUE.  ! non-constant if() condition, converted to i1
+
+!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
+!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
+!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ec"}
+
+  !$acc update host(a)
+!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! if() conditions become an i1 operand; a literal folds to arith.constant true.
+  !$acc update host(a) if(.true.)
+!CHECK: [[IF1:%.*]] = arith.constant true
+!CHECK: acc.update if([[IF1]]) host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc update host(a) if(ifCondition)
+!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
+!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
+!CHECK: acc.update if([[IF2]]) host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! Repeated host() clauses are merged into a single host operand list.
+  !$acc update host(a) host(b) host(c)
+!CHECK: acc.update host([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc update host(a) host(b) device(c)
+!CHECK: acc.update host([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) device([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! async/wait without an argument become unit attributes on acc.update.
+  !$acc update host(a) async
+!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+
+  !$acc update host(a) wait
+!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+
+  !$acc update host(a) async wait
+!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! async/wait with an argument become acc.update operands instead of attributes.
+  !$acc update host(a) async(1)
+!CHECK: [[ASYNC1:%.*]] = arith.constant 1 : i32
+!CHECK: acc.update async([[ASYNC1]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc update host(a) async(async)
+!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
+!CHECK: acc.update async([[ASYNC2]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc update host(a) wait(1)
+!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
+!CHECK: acc.update wait([[WAIT1]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc update host(a) wait(queues: 1, 2)
+!CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
+!CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+! The devnum value is emitted after the queue values and printed as wait_devnum.
+  !$acc update host(a) wait(devnum: 1: queues: 1, 2)
+!CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
+!CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
+!CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc update host(a) device_type(1, 2)
+!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
+!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
+!CHECK: acc.update device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! device_type(*) is represented by the constant -1 of index type.
+  !$acc update host(a) device_type(*)
+!CHECK: [[DEVTYPE3:%.*]] = arith.constant -1 : index
+!CHECK: acc.update device_type([[DEVTYPE3]] : index) host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+end subroutine acc_update


        


More information about the flang-commits mailing list