[clang] abfb2ce - [OpenACC][NFCI] Implement 'helpers' for all of the clauses I've used so far (#137396)

via cfe-commits cfe-commits at lists.llvm.org
Mon Apr 28 06:06:46 PDT 2025


Author: Erich Keane
Date: 2025-04-28T06:06:42-07:00
New Revision: abfb2ce2f57fc02e222936aeb602681add752d9b

URL: https://github.com/llvm/llvm-project/commit/abfb2ce2f57fc02e222936aeb602681add752d9b
DIFF: https://github.com/llvm/llvm-project/commit/abfb2ce2f57fc02e222936aeb602681add752d9b.diff

LOG: [OpenACC][NFCI] Implement 'helpers' for all of the clauses I've used so far (#137396)

As a follow up to 3c4dff3ac6884b85fe93fe512c5bdaf014738c45 I audited all
uses of 'process clause and use additive methods', and added explicit
functions to the construct to make it easier for the next project to
attempt to use this mechanism (vs construct all operands/etc in advance,
then add all at once).

I've only done ones that I have attempted to use so far (as a catch-up,
so no var-list clauses, and no constructs that can't be used without a
var-list, and no loop, and no compound constructs). I intend to do those
"as I go" with the lowering of each of those things instead.

---------

Co-authored-by: Andy Kaylor <akaylor at nvidia.com>

Added: 
    

Modified: 
    clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Removed: 
    


################################################################################
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
index 6e65f94c78bed..6f86d2b681a1e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
   // diagnostics are gone.
   SourceLocation dirLoc;
 
-  const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr;
+  llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
+
+  void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
+    lastDeviceTypeValues.clear();
+
+    llvm::for_each(clause.getArchitectures(),
+                   [this](const DeviceTypeArgument &arg) {
+                     lastDeviceTypeValues.push_back(
+                         decodeDeviceType(arg.getIdentifierInfo()));
+                   });
+  }
 
   void clauseNotImplemented(const OpenACCClause &c) {
     cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
         .CaseLower("radeon", mlir::acc::DeviceType::Radeon);
   }
 
-  // Overload of this function that only returns the device-types list.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
-    mlir::ValueRange argument;
-    mlir::MutableOperandRange range{operation};
-
-    return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range);
-  }
-  // Overload of this function for when 'segments' aren't necessary.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
-                                 mlir::ValueRange argument,
-                                 mlir::MutableOperandRange argCollection) {
-    llvm::SmallVector<int32_t> segments;
-    assert(argument.size() <= 1 &&
-           "Overload only for cases where segments don't need to be added");
-    return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
-                                          argCollection, segments);
-  }
-
-  // Handle a clause affected by the 'device_type' to the point that they need
-  // to have attributes added in the correct/corresponding order, such as
-  // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
-  // a collection of operands that need to be appended to the `argCollection` as
-  // we're adding a 'device_type' entry.  If there is more than 0 elements in
-  // the 'argument', the collection must be non-null, as it is needed to add to
-  // it.
-  // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
-  // be maintained, this takes a list of segments that will be updated with the
-  // proper counts as 'argument' elements are added.
-  //
-  // In MLIR, the 'operands' are stored as a large array, with a separate array
-  // of 'segments' that show which 'operand' applies to which 'operand-kind'.
-  // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
-  //
-  // So the operands array might have 4 elements, but the 'segments' array will
-  // be something like:
-  //
-  // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
-  //
-  // Where each position belongs to a specific 'operand-kind'.  So that
-  // specifies that whichever operand-kind corresponds with index '3' has 2
-  // elements, and should take the 1st 2 operands off the list (since all
-  // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
-  // 1 element.
-  //
-  // Fortunately, the `MutableOperandRange` append function actually takes care
-  // of that for us at the 'top level'.
-  //
-  // However, in cases like `num_gangs' or 'wait', where each individual
-  // 'element' might be itself array-like, there is a separate 'segments' array
-  // for them. So in the case of:
-  //
-  // device_type(nvidia, radeon) num_gangs(1, 2, 3)
-  //
-  // We have to emit that as TWO arrays into the IR (where the device_type is an
-  // attribute), so they look like:
-  //
-  // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
-  //           {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
-  //
-  // When stored in the 'operands' list, the top-level 'segment' for
-  // 'num_gangs' just shows 6 elements. In order to get the array-like
-  // apperance, the 'numGangsSegments' list is kept as well. In the above case,
-  // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
-  // elements, 1 per array, and each will have a value of 3.  The verifier will
-  // ensure that the collections counts are correct.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
-                                 mlir::ValueRange argument,
-                                 mlir::MutableOperandRange argCollection,
-                                 llvm::SmallVector<int32_t> &segments) {
-    llvm::SmallVector<mlir::Attribute> deviceTypes;
-
-    // Collect the 'existing' device-type attributes so we can re-create them
-    // and insert them.
-    if (existingDeviceTypes) {
-      for (const mlir::Attribute &Attr : existingDeviceTypes)
-        deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-            builder.getContext(),
-            cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
-    }
-
-    // Insert 1 version of the 'expr' to the NumWorkers list per-current
-    // device type.
-    if (lastDeviceTypeClause) {
-      for (const DeviceTypeArgument &arch :
-           lastDeviceTypeClause->getArchitectures()) {
-        deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-            builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
-        if (!argument.empty()) {
-          argCollection.append(argument);
-          segments.push_back(argument.size());
-        }
-      }
-    } else {
-      // Else, we just add a single for 'none'.
-      deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), mlir::acc::DeviceType::None));
-      if (!argument.empty()) {
-        argCollection.append(argument);
-        segments.push_back(argument.size());
-      }
-    }
-
-    return mlir::ArrayAttr::get(builder.getContext(), deviceTypes);
-  }
-
 public:
   OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
                           CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
   }
 
   void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
-    lastDeviceTypeClause = &clause;
+    setLastDeviceTypeClause(clause);
+
     if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
       llvm::for_each(
           clause.getArchitectures(), [this](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
     } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
                                       DataOp>) {
       // Nothing to do here, these constructs don't have any IR for these, as
-      // they just modify the other clauses IR.  So setting of `lastDeviceType`
-      // (done above) is all we need.
+      // they just modify the other clauses IR.  So setting of
+      // `lastDeviceTypeValues` (done above) is all we need.
     } else {
       // TODO: When we've implemented this for everything, switch this to an
       // unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final
 
   void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
-      mlir::MutableOperandRange range = operation.getNumWorkersMutable();
-      operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getNumWorkersDeviceTypeAttr(),
-          createIntExpr(clause.getIntExpr()), range));
+      operation.addNumWorkersOperand(builder.getContext(),
+                                     createIntExpr(clause.getIntExpr()),
+                                     lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
       llvm_unreachable("num_workers not valid on serial");
     } else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final
 
   void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
-      mlir::MutableOperandRange range = operation.getVectorLengthMutable();
-      operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getVectorLengthDeviceTypeAttr(),
-          createIntExpr(clause.getIntExpr()), range));
+      operation.addVectorLengthOperand(builder.getContext(),
+                                       createIntExpr(clause.getIntExpr()),
+                                       lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
       llvm_unreachable("vector_length not valid on serial");
     } else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final
 
   void VisitAsyncClause(const OpenACCAsyncClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
-      if (!clause.hasIntExpr()) {
-        operation.setAsyncOnlyAttr(
-            handleDeviceTypeAffectedClause(operation.getAsyncOnlyAttr()));
-      } else {
-        mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
-        operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-            operation.getAsyncOperandsDeviceTypeAttr(),
-            createIntExpr(clause.getIntExpr()), range));
-      }
+      if (!clause.hasIntExpr())
+        operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
+      else
+        operation.addAsyncOperand(builder.getContext(),
+                                  createIntExpr(clause.getIntExpr()),
+                                  lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
       // Wait doesn't have a device_type, so its handling here is slightly
       // different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
   void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
       llvm::SmallVector<mlir::Value> values;
-
       for (const Expr *E : clause.getIntExprs())
         values.push_back(createIntExpr(E));
 
-      llvm::SmallVector<int32_t> segments;
-      if (operation.getNumGangsSegments())
-        llvm::copy(*operation.getNumGangsSegments(),
-                   std::back_inserter(segments));
-
-      mlir::MutableOperandRange range = operation.getNumGangsMutable();
-      operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getNumGangsDeviceTypeAttr(), values, range, segments));
-      operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
+      operation.addNumGangsOperands(builder.getContext(), values,
+                                    lastDeviceTypeValues);
     } else {
       // TODO: When we've implemented this for everything, switch this to an
       // unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
   void VisitWaitClause(const OpenACCWaitClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
       if (!clause.hasExprs()) {
-        operation.setWaitOnlyAttr(
-            handleDeviceTypeAffectedClause(operation.getWaitOnlyAttr()));
+        operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
       } else {
         llvm::SmallVector<mlir::Value> values;
-
         if (clause.hasDevNumExpr())
           values.push_back(createIntExpr(clause.getDevNumExpr()));
         for (const Expr *E : clause.getQueueIdExprs())
           values.push_back(createIntExpr(E));
-
-        llvm::SmallVector<int32_t> segments;
-        if (operation.getWaitOperandsSegments())
-          llvm::copy(*operation.getWaitOperandsSegments(),
-                     std::back_inserter(segments));
-
-        unsigned beforeSegmentSize = segments.size();
-
-        mlir::MutableOperandRange range = operation.getWaitOperandsMutable();
-        operation.setWaitOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-            operation.getWaitOperandsDeviceTypeAttr(), values, range,
-            segments));
-        operation.setWaitOperandsSegments(segments);
-
-        // In addition to having to set the 'segments', wait also has a list of
-        // bool attributes whether it is annotated with 'devnum'.  We can use
-        // our knowledge of how much the 'segments' array grew to determine how
-        // many we need to add.
-        llvm::SmallVector<bool> hasDevNums;
-        if (operation.getHasWaitDevnumAttr())
-          for (mlir::Attribute A : operation.getHasWaitDevnumAttr())
-            hasDevNums.push_back(cast<mlir::BoolAttr>(A).getValue());
-
-        hasDevNums.insert(hasDevNums.end(), segments.size() - beforeSegmentSize,
-                          clause.hasDevNumExpr());
-
-        operation.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasDevNums));
+        operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
+                                  values, lastDeviceTypeValues);
       }
     } else {
       // TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
     if (s.hasDevNumExpr())
       waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
 
-    for (Expr *QueueExpr  : s.getQueueIdExprs())
+    for (Expr *QueueExpr : s.getQueueIdExprs())
       waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
   }
 

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 2167129e9e1c7..41cec89fdf598 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1408,6 +1408,31 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_parallel;
     }
+    /// Add a value to 'num_workers' with the current list of device types.
+    void addNumWorkersOperand(MLIRContext *, mlir::Value,
+                              llvm::ArrayRef<DeviceType>);
+    /// Add a value to 'vector_length' with the current list of device types.
+    void addVectorLengthOperand(MLIRContext *, mlir::Value,
+                                llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry to the 'num_gangs' with the current list of
+    /// device types.
+    void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
+                             llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1535,6 +1560,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_serial;
     }
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1679,6 +1719,31 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_kernels;
     }
+    /// Add a value to 'num_workers' with the current list of device types.
+    void addNumWorkersOperand(MLIRContext *, mlir::Value,
+                              llvm::ArrayRef<DeviceType>);
+    /// Add a value to 'vector_length' with the current list of device types.
+    void addVectorLengthOperand(MLIRContext *, mlir::Value,
+                                llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry to the 'num_gangs' with the current list of
+    /// device types.
+    void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
+                             llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1785,6 +1850,21 @@ def OpenACC_DataOp : OpenACC_Op<"data",
     /// Return the wait devnum value clause for the given device_type if
     /// present.
     mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments) for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 04cbe200eafe9..56f3228d3a652 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -76,6 +76,69 @@ struct LLVMPointerPointerLikeModel
                                             LLVM::LLVMPointerType> {
   Type getElementType(Type pointer) const { return Type(); }
 };
+
+/// Helper function for any of the times we need to modify an ArrayAttr based on
+/// a device type list.  Returns a new ArrayAttr with all of the
+/// existingDeviceTypes, plus the effective new ones(or an added none if the new
+/// list is empty).
+mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
+    MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
+    llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
+  if (existingDeviceTypes)
+    llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
+
+  if (newDeviceTypes.empty())
+    deviceTypes.push_back(
+        acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
+
+  for (DeviceType DT : newDeviceTypes)
+    deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
+
+  return mlir::ArrayAttr::get(context, deviceTypes);
+}
+
+/// Helper function for any of the times we need to add operands that are
+/// affected by a device type list. Returns a new ArrayAttr with all of the
+/// existingDeviceTypes, plus the effective new ones (or an added none, if the
+/// new list is empty). Additionally, adds the arguments to the argCollection
+/// the correct number of times. This will also update a 'segments' array, even
+/// if it won't be used.
+mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
+    MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
+    llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
+    mlir::MutableOperandRange argCollection,
+    llvm::SmallVector<int32_t> &segments) {
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
+  if (existingDeviceTypes)
+    llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
+
+  if (newDeviceTypes.empty()) {
+    argCollection.append(arguments);
+    segments.push_back(arguments.size());
+    deviceTypes.push_back(
+        acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
+  }
+
+  for (DeviceType DT : newDeviceTypes) {
+    argCollection.append(arguments);
+    segments.push_back(arguments.size());
+    deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
+  }
+
+  return mlir::ArrayAttr::get(context, deviceTypes);
+}
+
+/// Overload for when the 'segments' aren't needed.
+mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
+    MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
+    llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
+    mlir::MutableOperandRange argCollection) {
+  llvm::SmallVector<int32_t> segments;
+  return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
+                                            newDeviceTypes, arguments,
+                                            argCollection, segments);
+}
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1170,6 +1233,76 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
       /*defaultAttr=*/nullptr, /*combined=*/nullptr);
 }
 
+void acc::ParallelOp::addNumWorkersOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getNumWorkersMutable()));
+}
+void acc::ParallelOp::addVectorLengthOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getVectorLengthMutable()));
+}
+
+void acc::ParallelOp::addAsyncOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOnlyAttr(), effectiveDeviceTypes));
+}
+
+void acc::ParallelOp::addAsyncOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getAsyncOperandsMutable()));
+}
+
+void acc::ParallelOp::addNumGangsOperands(
+    MLIRContext *context, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  llvm::SmallVector<int32_t> segments;
+  if (getNumGangsSegments())
+    llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
+
+  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getNumGangsMutable(), segments));
+
+  setNumGangsSegments(segments);
+}
+void acc::ParallelOp::addWaitOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
+                                                     effectiveDeviceTypes));
+}
+void acc::ParallelOp::addWaitOperands(
+    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+
+  llvm::SmallVector<int32_t> segments;
+  if (getWaitOperandsSegments())
+    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
+
+  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getWaitOperandsMutable(), segments));
+  setWaitOperandsSegments(segments);
+
+  llvm::SmallVector<mlir::Attribute> hasDevnums;
+  if (getHasWaitDevnumAttr())
+    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
+  hasDevnums.insert(
+      hasDevnums.end(),
+      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
+      mlir::BoolAttr::get(context, hasDevnum));
+  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
+}
+
 static ParseResult parseNumGangs(
     mlir::OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -1686,6 +1819,48 @@ LogicalResult acc::SerialOp::verify() {
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }
 
+void acc::SerialOp::addAsyncOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOnlyAttr(), effectiveDeviceTypes));
+}
+
+void acc::SerialOp::addAsyncOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getAsyncOperandsMutable()));
+}
+
+void acc::SerialOp::addWaitOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
+                                                     effectiveDeviceTypes));
+}
+void acc::SerialOp::addWaitOperands(
+    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+
+  llvm::SmallVector<int32_t> segments;
+  if (getWaitOperandsSegments())
+    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
+
+  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getWaitOperandsMutable(), segments));
+  setWaitOperandsSegments(segments);
+
+  llvm::SmallVector<mlir::Attribute> hasDevnums;
+  if (getHasWaitDevnumAttr())
+    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
+  hasDevnums.insert(
+      hasDevnums.end(),
+      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
+      mlir::BoolAttr::get(context, hasDevnum));
+  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
+}
+
 //===----------------------------------------------------------------------===//
 // KernelsOp
 //===----------------------------------------------------------------------===//
@@ -1813,6 +1988,77 @@ LogicalResult acc::KernelsOp::verify() {
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 
+void acc::KernelsOp::addNumWorkersOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getNumWorkersMutable()));
+}
+
+void acc::KernelsOp::addVectorLengthOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getVectorLengthMutable()));
+}
+void acc::KernelsOp::addAsyncOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOnlyAttr(), effectiveDeviceTypes));
+}
+
+void acc::KernelsOp::addAsyncOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getAsyncOperandsMutable()));
+}
+
+void acc::KernelsOp::addNumGangsOperands(
+    MLIRContext *context, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  llvm::SmallVector<int32_t> segments;
+  if (getNumGangsSegmentsAttr())
+    llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
+
+  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getNumGangsMutable(), segments));
+
+  setNumGangsSegments(segments);
+}
+
+void acc::KernelsOp::addWaitOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
+                                                     effectiveDeviceTypes));
+}
+void acc::KernelsOp::addWaitOperands(
+    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+
+  llvm::SmallVector<int32_t> segments;
+  if (getWaitOperandsSegments())
+    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
+
+  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getWaitOperandsMutable(), segments));
+  setWaitOperandsSegments(segments);
+
+  llvm::SmallVector<mlir::Attribute> hasDevnums;
+  if (getHasWaitDevnumAttr())
+    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
+  hasDevnums.insert(
+      hasDevnums.end(),
+      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
+      mlir::BoolAttr::get(context, hasDevnum));
+  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
+}
+
 //===----------------------------------------------------------------------===//
 // HostDataOp
 //===----------------------------------------------------------------------===//
@@ -2439,6 +2685,49 @@ mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
                             deviceType);
 }
 
+void acc::DataOp::addAsyncOnly(
+    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOnlyAttr(), effectiveDeviceTypes));
+}
+
+void acc::DataOp::addAsyncOperand(
+    MLIRContext *context, mlir::Value newValue,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
+      getAsyncOperandsMutable()));
+}
+
+void acc::DataOp::addWaitOnly(MLIRContext *context,
+                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
+                                                     effectiveDeviceTypes));
+}
+
+void acc::DataOp::addWaitOperands(
+    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
+    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+
+  llvm::SmallVector<int32_t> segments;
+  if (getWaitOperandsSegments())
+    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
+
+  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
+      getWaitOperandsMutable(), segments));
+  setWaitOperandsSegments(segments);
+
+  llvm::SmallVector<mlir::Attribute> hasDevnums;
+  if (getHasWaitDevnumAttr())
+    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
+  hasDevnums.insert(
+      hasDevnums.end(),
+      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
+      mlir::BoolAttr::get(context, hasDevnum));
+  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
+}
+
 //===----------------------------------------------------------------------===//
 // ExitDataOp
 //===----------------------------------------------------------------------===//


        


More information about the cfe-commits mailing list