[Mlir-commits] [mlir] [flang] [mlir][openacc] Add device_type support for data operation (PR #76126)

Valentin Clement バレンタイン クレメン llvmlistbot at llvm.org
Thu Jan 4 11:08:22 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/76126

From 929fba20132162c00898ed5f33364eba34a76675 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 20 Dec 2023 16:14:28 -0800
Subject: [PATCH] [mlir][openacc] Add device_type support for data operation

---
 flang/lib/Lower/OpenACC.cpp                   | 125 ++++++++++++------
 flang/test/Lower/OpenACC/acc-data.f90         |   8 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  47 +++++--
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       |  43 +++++-
 mlir/test/Dialect/OpenACC/ops.mlir            |   8 +-
 5 files changed, 176 insertions(+), 55 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecf70818c4ac0f..d10e56e5d11779 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1464,6 +1464,24 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static void
+genAsyncClause(Fortran::lower::AbstractConverter &converter,
+               const Fortran::parser::AccClause::Async *asyncClause,
+               llvm::SmallVector<mlir::Value> &async,
+               llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
+               llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
+               mlir::acc::DeviceTypeAttr deviceTypeAttr,
+               Fortran::lower::StatementContext &stmtCtx) {
+  const auto &asyncClauseValue = asyncClause->v;
+  if (asyncClauseValue) { // async has a value.
+    async.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+    asyncDeviceTypes.push_back(deviceTypeAttr);
+  } else {
+    asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+  }
+}
+
 static mlir::acc::DeviceType
 getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
   switch (device) {
@@ -1533,6 +1551,39 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static void
+genWaitClause(Fortran::lower::AbstractConverter &converter,
+              const Fortran::parser::AccClause::Wait *waitClause,
+              llvm::SmallVector<mlir::Value> &waitOperands,
+              llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+              llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+              llvm::SmallVector<int32_t> &waitOperandsSegments,
+              mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+              Fortran::lower::StatementContext &stmtCtx) {
+  const auto &waitClauseValue = waitClause->v;
+  if (waitClauseValue) { // wait has a value.
+    const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+    const auto &waitList =
+        std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    auto crtWaitOperands = waitOperands.size();
+    for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+      waitOperands.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(value), stmtCtx)));
+    }
+    waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+    waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+
+    // TODO: move to device_type model.
+    const auto &waitDevnumValue =
+        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    if (waitDevnumValue)
+      waitDevnum = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
+  } else {
+    waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+  }
+}
+
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
@@ -1795,6 +1846,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
+  mlir::Value waitDevnum; // TODO not yet implemented on compute op.
 
   // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -1818,31 +1870,14 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      const auto &asyncClauseValue = asyncClause->v;
-      if (asyncClauseValue) { // async has a value.
-        async.push_back(fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
-        asyncDeviceTypes.push_back(crtDeviceTypeAttr);
-      } else {
-        asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
-      }
+      genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
+                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      const auto &waitClauseValue = waitClause->v;
-      if (waitClauseValue) { // wait has a value.
-        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
-        const auto &waitList =
-            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-        auto crtWaitOperands = waitOperands.size();
-        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
-          waitOperands.push_back(fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(value), stmtCtx)));
-        }
-        waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
-        waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
-      } else {
-        waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
-      }
+      genWaitClause(converter, waitClause, waitOperands,
+                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
+                    stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
@@ -2126,21 +2161,24 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          const Fortran::parser::AccClauseList &accClauseList) {
-  mlir::Value ifCond, async, waitDevnum;
+  mlir::Value ifCond, waitDevnum;
   llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
-      copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;
-
-  // Async and wait have an optional value but can be present with
-  // no value as well. When there is no value, the op has an attribute to
-  // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
+      copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
+      async;
+  llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
+      waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> waitOperandsSegments;
 
   bool hasDefaultNone = false;
   bool hasDefaultPresent = false;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
+
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
@@ -2221,11 +2259,14 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                                  dataClauseOperands.end());
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
+                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      genWaitClause(converter, waitClause, waitOperands,
+                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
+                    stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
       if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2239,7 +2280,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, ifCond);
-  addOperand(operands, operandSegments, async);
+  addOperands(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
@@ -2250,8 +2291,18 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
       builder, currentLocation, eval, operands, operandSegments);
 
-  dataOp.setAsyncAttr(addAsyncAttr);
-  dataOp.setWaitAttr(addWaitAttr);
+  if (!asyncDeviceTypes.empty())
+    dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+  if (!asyncOnlyDeviceTypes.empty())
+    dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+  if (!waitOperandsDeviceTypes.empty())
+    dataOp.setWaitOperandsDeviceTypeAttr(
+        builder.getArrayAttr(waitOperandsDeviceTypes));
+  if (!waitOperandsSegments.empty())
+    dataOp.setWaitOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!waitOnlyDeviceTypes.empty())
+    dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
   if (hasDefaultNone)
     dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index a6572e14707606..75ffd1fc3fcab2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -153,7 +153,7 @@ subroutine acc_data
   !$acc end data
 
 ! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {asyncAttr}
+! CHECK: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc data present(a) async(1)
   !$acc end data
@@ -165,18 +165,18 @@ subroutine acc_data
   !$acc end data
 
 ! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitAttr}
+! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc data present(a) wait(1)
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) wait(%{{.*}} : i32) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({%{{.*}} : i32}) {
 ! CHECK: }{{$}}
 
   !$acc data present(a) wait(devnum: 0: 1)
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait(%{{.*}} : i32) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
 ! CHECK: }{{$}}
 
   !$acc data default(none)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 4312bd4de1bd4f..1dd83e933034ab 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1236,13 +1236,16 @@ def OpenACC_DataOp : OpenACC_Op<"data",
 
 
   let arguments = (ins Optional<I1>:$ifCond,
-                       Optional<IntOrIndex>:$async,
-                       UnitAttr:$asyncAttr,
-                       Optional<IntOrIndex>:$waitDevnum,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$waitAttr,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      Variadic<IntOrIndex>:$async,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+      Optional<IntOrIndex>:$waitDevnum,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      OptionalAttr<DefaultValueAttr>:$defaultAttr);
 
   let regions = (region AnyRegion:$region);
 
@@ -1252,15 +1255,41 @@ def OpenACC_DataOp : OpenACC_Op<"data",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsyncOnly();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWaitOnly();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `if` `(` $ifCond `)`
-      | `async` `(` $async `:` type($async) `)`
+      | `async` `(` custom<DeviceTypeOperands>($async,
+            type($async), $asyncDeviceType) `)`
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, 
+            $waitOperandsSegments) `)`
     )
     $region attr-dict-with-keyword
   }];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e299b67b10a9c7..66605ead0529df 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1417,11 +1417,52 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
 
 Value DataOp::getDataOperand(unsigned i) {
   unsigned numOptional = getIfCond() ? 1 : 0;
-  numOptional += getAsync() ? 1 : 0;
+  numOptional += getAsync().size() ? 1 : 0;
   numOptional += getWaitOperands().size();
   return getOperand(numOptional + i);
 }
 
+bool acc::DataOp::hasAsyncOnly() {
+  return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  if (auto arrayAttr = getAsyncOnly()) {
+    if (findSegment(*arrayAttr, deviceType))
+      return true;
+  }
+  return false;
+}
+
+mlir::Value DataOp::getAsyncValue() {
+  return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
+                                     deviceType);
+}
+
+bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
+
+bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  if (auto arrayAttr = getWaitOnly()) {
+    if (findSegment(*arrayAttr, deviceType))
+      return true;
+  }
+  return false;
+}
+
+mlir::Operation::operand_range DataOp::getWaitValues() {
+  return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+                               getWaitOperandsSegments(), deviceType);
+}
+
 //===----------------------------------------------------------------------===//
 // ExitDataOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 5a95811685f845..52375b1af3141c 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -836,11 +836,11 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
   } attributes { defaultAttr = #acc<defaultvalue none>, wait }
 
   %w1 = arith.constant 1 : i64
-  acc.data wait(%w1 : i64) {
+  acc.data wait({%w1 : i64}) {
   } attributes { defaultAttr = #acc<defaultvalue none>, wait }
 
   %wd1 = arith.constant 1 : i64
-  acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) {
+  acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
   } attributes { defaultAttr = #acc<defaultvalue none>, wait }
 
   return
@@ -951,10 +951,10 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
 // CHECK:      acc.data {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
 
-// CHECK:      acc.data wait(%{{.*}} : i64) {
+// CHECK:      acc.data wait({%{{.*}} : i64}) {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
 
-// CHECK:      acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) {
+// CHECK:      acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
 
 // -----



More information about the Mlir-commits mailing list