[flang-commits] [flang] d21580c - [MLIR][OpenMP]Add Flang lowering support for device_ptr and device_addr clauses

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Thu Jun 22 07:52:42 PDT 2023


Author: Akash Banerjee
Date: 2023-06-22T15:52:33+01:00
New Revision: d21580c3065740ac32f3713d737f302ed7f0ac94

URL: https://github.com/llvm/llvm-project/commit/d21580c3065740ac32f3713d737f302ed7f0ac94
DIFF: https://github.com/llvm/llvm-project/commit/d21580c3065740ac32f3713d737f302ed7f0ac94.diff

LOG: [MLIR][OpenMP]Add Flang lowering support for device_ptr and device_addr clauses

Add lowering support for the use_device_ptr and use_device_addr clauses of the `target data` directive.

Depends on D152822

Differential Revision: https://reviews.llvm.org/D152824

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenMP/target.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 141912a18e29e..42fad0d483d19 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -723,6 +723,48 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static void createBodyOfTargetOp(
+    Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
+    const llvm::SmallVector<mlir::Type> &useDeviceTypes,
+    const llvm::SmallVector<mlir::Location> &useDeviceLocs,
+    const SmallVector<const Fortran::semantics::Symbol *> &useDeviceSymbols,
+    const mlir::Location &currentLocation) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Region &region = dataOp.getRegion();
+
+  firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&region.front());
+
+  unsigned argIndex = 0;
+  for (auto *sym : useDeviceSymbols) {
+    const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+    mlir::Value val = fir::getBase(arg);
+    fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+    if (auto refType = val.getType().dyn_cast<fir::ReferenceType>()) {
+      if (fir::isa_builtin_cptr_type(refType.getElementType())) {
+        converter.bindSymbol(*sym, val);
+      } else {
+        extVal.match(
+            [&](const fir::MutableBoxValue &mbv) {
+              converter.bindSymbol(
+                  *sym,
+                  fir::MutableBoxValue(
+                      val, fir::factory::getNonDeferredLenParams(extVal), {}));
+            },
+            [&](const auto &) {
+              TODO(converter.getCurrentLocation(),
+                   "use_device clause operand unsupported type");
+            });
+      }
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "use_device clause operand unsupported type");
+    }
+    argIndex++;
+  }
+}
+
 static void createTargetOp(Fortran::lower::AbstractConverter &converter,
                            const Fortran::parser::OmpClauseList &opClauseList,
                            const llvm::omp::Directive &directive,
@@ -732,13 +774,24 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
 
   mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
-      mapOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
+      deviceAddrOperands;
   llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+  llvm::SmallVector<mlir::Type> useDeviceTypes;
+  llvm::SmallVector<mlir::Location> useDeviceLocs;
+  SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+
+  /// Check for unsupported map operand types.
+  auto checkType = [](auto currentLocation, mlir::Type type) {
+    if (auto refType = type.dyn_cast<fir::ReferenceType>())
+      type = refType.getElementType();
+    if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
+      if (!boxType.getElementType().isa<fir::PointerType>())
+        TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
+  };
 
-  auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
-                       &mapTypes](const auto &mapClause,
-                                  mlir::Location &currentLocation) {
+  auto addMapClause = [&](const auto &mapClause,
+                          mlir::Location &currentLocation) {
     auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
         std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
             ->t);
@@ -793,18 +846,25 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
                   converter, mapOperand);
 
     for (mlir::Value mapOp : mapOperand) {
-      /// Check for unsupported map operand types.
-      mlir::Type checkType = mapOp.getType();
-      if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
-        checkType = refType.getElementType();
-      if (checkType.isa<fir::BoxType>())
-        TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
-
+      checkType(mapOp.getLoc(), mapOp.getType());
       mapOperands.push_back(mapOp);
       mapTypes.push_back(mapTypeAttr);
     }
   };
 
+  auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
+    genObjectList(useDeviceClause, converter, operands);
+    for (auto &operand : operands) {
+      checkType(operand.getLoc(), operand.getType());
+      useDeviceTypes.push_back(operand.getType());
+      useDeviceLocs.push_back(operand.getLoc());
+    }
+    for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
+      Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+      useDeviceSymbols.push_back(sym);
+    }
+  };
+
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     mlir::Location currentLocation = converter.genLocation(clause.source);
     if (const auto &ifClause =
@@ -825,12 +885,6 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
         deviceOperand =
             fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
       }
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Ptr");
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
-                   &clause.u)) {
-      TODO(currentLocation, "OMPD_target Use Device Addr");
     } else if (const auto &threadLmtClause =
                    std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                        &clause.u)) {
@@ -838,6 +892,14 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
           *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
     } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
       nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &devPtrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                       &clause.u)) {
+      addUseDeviceClause(devPtrClause->v, devicePtrOperands);
+    } else if (const auto &devAddrClause =
+                   std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                       &clause.u)) {
+      addUseDeviceClause(devAddrClause->v, deviceAddrOperands);
     } else if (const auto &mapClause =
                    std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       addMapClause(mapClause, currentLocation);
@@ -859,9 +921,10 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
     createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
   } else if (directive == llvm::omp::Directive::OMPD_target_data) {
     auto dataOp = firOpBuilder.create<omp::DataOp>(
-        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
-        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
-    createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
+        currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
+        deviceAddrOperands, mapOperands, mapTypesArrayAttr);
+    createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+                         useDeviceSymbols, currentLocation);
   } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
     firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
                                           deviceOperand, nowaitAttr,
@@ -1157,7 +1220,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
       // Map clause is exclusive to Target Data directives. It is handled
-      // as part of the DataOp creation.
+      // as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      // UseDevicePtr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      // UseDeviceAddr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
       continue;
     } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
                    &clause.u)) {

diff  --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 0e8574a821940..44b2b7b53531e 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -162,3 +162,39 @@ subroutine omp_target_thread_limit
    !$omp end target
    !CHECK: }
 end subroutine omp_target_thread_limit
+
+!===============================================================================
+! Target `use_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
+subroutine omp_target_device_ptr
+   use iso_c_binding, only : c_ptr, c_loc
+   type(c_ptr) :: a
+   integer, target :: b
+   !CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)) use_device_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
+   !$omp target data map(tofrom: a) use_device_ptr(a)
+   !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
+   !CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+      a = c_loc(b)
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_ptr
+
+ !===============================================================================
+ ! Target `use_device_addr` clause
+ !===============================================================================
+
+ !CHECK-LABEL: func.func @_QPomp_target_device_addr() {
+ subroutine omp_target_device_addr
+   integer, pointer :: a
+   !CHECK:   omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !$omp target data map(tofrom: a) use_device_addr(a)
+   !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
+   !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+      a = 10
+   !CHECK: omp.terminator
+   !$omp end target data
+   !CHECK: }
+end subroutine omp_target_device_addr


        


More information about the flang-commits mailing list