[Mlir-commits] [mlir] [MLIR][OpenMP] Group clause operands into structures (PR #86797)

Sergio Afonso llvmlistbot at llvm.org
Thu Mar 28 09:24:42 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/86797

>From c9ee93f1c28e9d4518ec842f66f497bf0911c9d5 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:31:34 +0000
Subject: [PATCH] [MLIR][OpenMP] Group clause operands into structures

This patch introduces a set of composable structures grouping the MLIR operands
associated with each OpenMP clause. This makes it easier to keep the MLIR
representation for the same clause consistent throughout all operations that
accept it.

The relevant clause operand structures are grouped into per-operation
structures using a mixin pattern and used to define new operation constructors.
These constructors can be used to avoid having to get the order of a possibly
large list of operands right.

Missing clauses are documented as TODOs, as are operands which are part of the
relevant operation's operand structure but cannot be attached to the associated
operation yet because its MLIR definition lacks the corresponding op arguments.

A follow-up patch will update Flang lowering to make use of these structures,
simplifying the passing of information from clause processing to operation-
generating functions and also simplifying the creation of operations through
the use of the new operation constructors.
---
 .../Dialect/OpenMP/OpenMPClauseOperands.h     | 300 ++++++++++++++++++
 .../mlir/Dialect/OpenMP/OpenMPDialect.h       |   7 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  72 ++++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 226 ++++++++++++-
 4 files changed, 595 insertions(+), 10 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
new file mode 100644
index 00000000000000..6454076f7593b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -0,0 +1,300 @@
+//===-- OpenMPClauseOperands.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the structures defining MLIR operands associated with each
+// OpenMP clause, and structures grouping the appropriate operands for each
+// construct.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+
+namespace mlir {
+namespace omp {
+
+//===----------------------------------------------------------------------===//
+// Mixin structures defining MLIR operands associated with each OpenMP clause.
+//===----------------------------------------------------------------------===//
+
+struct AlignedClauseOps {
+  llvm::SmallVector<Value> alignedVars;
+  llvm::SmallVector<Attribute> alignmentAttrs;
+};
+
+struct AllocateClauseOps {
+  llvm::SmallVector<Value> allocatorVars, allocateVars;
+};
+
+struct CollapseClauseOps {
+  llvm::SmallVector<Value> loopLBVar, loopUBVar, loopStepVar;
+};
+
+struct CopyprivateClauseOps {
+  llvm::SmallVector<Value> copyprivateVars;
+  llvm::SmallVector<Attribute> copyprivateFuncs;
+};
+
+struct DependClauseOps {
+  llvm::SmallVector<Attribute> dependTypeAttrs;
+  llvm::SmallVector<Value> dependVars;
+};
+
+struct DeviceClauseOps {
+  Value deviceVar;
+};
+
+struct DeviceTypeClauseOps {
+  // The device type; defaults to `any`.
+  DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
+};
+
+struct DistScheduleClauseOps {
+  UnitAttr distScheduleStaticAttr;
+  Value distScheduleChunkSizeVar;
+};
+
+struct DoacrossClauseOps {
+  llvm::SmallVector<Value> doacrossVectorVars;
+  ClauseDependAttr doacrossDependTypeAttr;
+  IntegerAttr doacrossNumLoopsAttr;
+};
+
+struct FinalClauseOps {
+  Value finalVar;
+};
+
+struct GrainsizeClauseOps {
+  Value grainsizeVar;
+};
+
+struct HintClauseOps {
+  IntegerAttr hintAttr;
+};
+
+struct IfClauseOps {
+  Value ifVar;
+};
+
+struct InReductionClauseOps {
+  llvm::SmallVector<Value> inReductionVars;
+  llvm::SmallVector<Attribute> inReductionDeclSymbols;
+};
+
+struct LinearClauseOps {
+  llvm::SmallVector<Value> linearVars, linearStepVars;
+};
+
+struct LoopRelatedOps {
+  UnitAttr loopInclusiveAttr;
+};
+
+struct MapClauseOps {
+  llvm::SmallVector<Value> mapVars;
+};
+
+struct MergeableClauseOps {
+  UnitAttr mergeableAttr;
+};
+
+struct NameClauseOps {
+  StringAttr nameAttr;
+};
+
+struct NogroupClauseOps {
+  UnitAttr nogroupAttr;
+};
+
+struct NontemporalClauseOps {
+  llvm::SmallVector<Value> nontemporalVars;
+};
+
+struct NowaitClauseOps {
+  UnitAttr nowaitAttr;
+};
+
+struct NumTasksClauseOps {
+  Value numTasksVar;
+};
+
+struct NumTeamsClauseOps {
+  Value numTeamsLowerVar, numTeamsUpperVar;
+};
+
+struct NumThreadsClauseOps {
+  Value numThreadsVar;
+};
+
+struct OrderClauseOps {
+  ClauseOrderKindAttr orderAttr;
+};
+
+struct OrderedClauseOps {
+  IntegerAttr orderedAttr;
+};
+
+struct ParallelizationLevelClauseOps {
+  UnitAttr parLevelSimdAttr;
+};
+
+struct PriorityClauseOps {
+  Value priorityVar;
+};
+
+struct PrivateClauseOps {
+  // SSA values that correspond to "original" values being privatized.
+  // They refer to the SSA value outside the OpenMP region from which a clone is
+  // created inside the region.
+  llvm::SmallVector<Value> privateVars;
+  // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
+  // ops).
+  llvm::SmallVector<Attribute> privatizers;
+};
+
+struct ProcBindClauseOps {
+  ClauseProcBindKindAttr procBindKindAttr;
+};
+
+struct ReductionClauseOps {
+  llvm::SmallVector<Value> reductionVars;
+  llvm::SmallVector<Attribute> reductionDeclSymbols;
+  UnitAttr reductionByRefAttr;
+};
+
+struct SafelenClauseOps {
+  IntegerAttr safelenAttr;
+};
+
+struct ScheduleClauseOps {
+  ClauseScheduleKindAttr scheduleValAttr;
+  ScheduleModifierAttr scheduleModAttr;
+  Value scheduleChunkVar;
+  UnitAttr scheduleSimdAttr;
+};
+
+struct SimdlenClauseOps {
+  IntegerAttr simdlenAttr;
+};
+
+struct TaskReductionClauseOps {
+  llvm::SmallVector<Value> taskReductionVars;
+  llvm::SmallVector<Attribute> taskReductionDeclSymbols;
+};
+
+struct ThreadLimitClauseOps {
+  Value threadLimitVar;
+};
+
+struct UntiedClauseOps {
+  UnitAttr untiedAttr;
+};
+
+struct UseDeviceClauseOps {
+  llvm::SmallVector<Value> useDevicePtrVars, useDeviceAddrVars;
+};
+
+//===----------------------------------------------------------------------===//
+// Structures defining clause operands associated with each OpenMP leaf
+// construct.
+//
+// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+
+using CriticalClauseOps = detail::Clauses<HintClauseOps, NameClauseOps>;
+
+// TODO `indirect` clause.
+using DeclareTargetClauseOps = detail::Clauses<DeviceTypeClauseOps>;
+
+using DistributeClauseOps =
+    detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
+                    PrivateClauseOps>;
+
+// TODO `filter` clause.
+using MaskedClauseOps = detail::Clauses<>;
+
+using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
+
+using OrderedRegionClauseOps = detail::Clauses<ParallelizationLevelClauseOps>;
+
+using ParallelClauseOps =
+    detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
+                    PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
+
+using SectionsClauseOps = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
+                                          PrivateClauseOps, ReductionClauseOps>;
+
+// TODO `linear` clause.
+using SimdLoopClauseOps =
+    detail::Clauses<AlignedClauseOps, CollapseClauseOps, IfClauseOps,
+                    LoopRelatedOps, NontemporalClauseOps, OrderClauseOps,
+                    PrivateClauseOps, ReductionClauseOps, SafelenClauseOps,
+                    SimdlenClauseOps>;
+
+using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
+                                        NowaitClauseOps, PrivateClauseOps>;
+
+// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators`
+// clauses.
+using TargetClauseOps =
+    detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
+                    IfClauseOps, InReductionClauseOps, MapClauseOps,
+                    NowaitClauseOps, PrivateClauseOps, ReductionClauseOps,
+                    ThreadLimitClauseOps>;
+
+using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps,
+                                            MapClauseOps, UseDeviceClauseOps>;
+
+using TargetEnterExitUpdateDataClauseOps =
+    detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
+                    NowaitClauseOps>;
+
+// TODO `affinity`, `detach` clauses.
+using TaskClauseOps =
+    detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
+                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+                    PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
+
+using TaskgroupClauseOps =
+    detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
+
+using TaskloopClauseOps =
+    detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps,
+                    GrainsizeClauseOps, IfClauseOps, InReductionClauseOps,
+                    LoopRelatedOps, MergeableClauseOps, NogroupClauseOps,
+                    NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps,
+                    ReductionClauseOps, UntiedClauseOps>;
+
+using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>;
+
+using TeamsClauseOps =
+    detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
+                    PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
+
+using WsloopClauseOps =
+    detail::Clauses<AllocateClauseOps, CollapseClauseOps, LinearClauseOps,
+                    LoopRelatedOps, NowaitClauseOps, OrderClauseOps,
+                    OrderedClauseOps, PrivateClauseOps, ReductionClauseOps,
+                    ScheduleClauseOps>;
+
+} // namespace omp
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 23509c5b607016..c656bdc870976f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -26,11 +26,10 @@
 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
 
-#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
 
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f33942b3c7c02d..2643348d668698 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -287,7 +287,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
   let regions = (region AnyRegion:$region);
 
   let builders = [
-    OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+    OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins CArg<"const ParallelClauseOps &">:$clauses)>
   ];
   let extraClassDeclaration = [{
     /// Returns the number of reduction variables.
@@ -362,6 +363,10 @@ def TeamsOp : OpenMP_Op<"teams", [
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TeamsClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(
       `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
@@ -451,6 +456,10 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
 
   let regions = (region SizedRegion<1>:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const SectionsClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist( `reduction` `(`
               custom<ReductionVarList>(
@@ -495,6 +504,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const SingleClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`allocate` `(`
               custom<AllocateAndAllocator>(
@@ -601,6 +614,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
     OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
                "ValueRange":$step,
                CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)>
   ];
 
   let regions = (region AnyRegion:$region);
@@ -698,6 +712,11 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
      );
 
   let regions = (region AnyRegion:$region);
+
+  let builders = [
+    OpBuilder<(ins CArg<"const SimdLoopClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`aligned` `(`
               custom<AlignedClause>($aligned_vars, type($aligned_vars),
@@ -781,6 +800,10 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const DistributeClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`dist_schedule_static` $dist_schedule_static
           |`chunk_size` `(` $chunk_size `:` type($chunk_size) `)`
@@ -883,6 +906,9 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
                        Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars);
   let regions = (region AnyRegion:$region);
+  let builders = [
+    OpBuilder<(ins CArg<"const TaskClauseOps &">:$clauses)>
+  ];
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `)`
           |`final` `(` $final_expr `)`
@@ -1037,6 +1063,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TaskloopClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `)`
           |`final` `(` $final_expr `)`
@@ -1106,6 +1136,10 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TaskgroupClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`task_reduction` `(`
               custom<ReductionVarList>(
@@ -1432,6 +1466,10 @@ def TargetDataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments,
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TargetDataClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `:` type($if_expr) `)`
     | `device` `(` $device `:` type($device) `)`
@@ -1486,6 +1524,10 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data",
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `:` type($if_expr) `)`
     | `device` `(` $device `:` type($device) `)`
@@ -1540,6 +1582,10 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data",
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `:` type($if_expr) `)`
     | `device` `(` $device `:` type($device) `)`
@@ -1596,6 +1642,10 @@ def TargetUpdateOp: OpenMP_Op<"target_update", [AttrSizedOperandSegments,
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$map_operands);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist(`if` `(` $if_expr `:` type($if_expr) `)`
     | `device` `(` $device `:` type($device) `)`
@@ -1649,6 +1699,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TargetClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     oilist( `if` `(` $if_expr `)`
     | `device` `(` $device `:` type($device) `)`
@@ -1693,6 +1747,10 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
   let arguments = (ins SymbolNameAttr:$sym_name,
                        DefaultValuedAttr<I64Attr, "0">:$hint_val);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const CriticalClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     $sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
     attr-dict
@@ -1773,6 +1831,10 @@ def OrderedOp : OpenMP_Op<"ordered"> {
              ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$num_loops_val,
              Variadic<AnyType>:$depend_vec_vars);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const OrderedOpClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{
     ( `depend_type` `` $depend_type_val^ )?
     ( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )?
@@ -1797,6 +1859,10 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region"> {
 
   let regions = (region AnyRegion:$region);
 
+  let builders = [
+    OpBuilder<(ins CArg<"const OrderedRegionClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}];
   let hasVerifier = 1;
 }
@@ -1812,6 +1878,10 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
     of the current task.
   }];
 
+  let builders = [
+    OpBuilder<(ins CArg<"const TaskwaitClauseOps &">:$clauses)>
+  ];
+
   let assemblyFormat = "attr-dict";
 }
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf5875071e0dc4..28869c1ddfb3fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -41,6 +41,11 @@
 using namespace mlir;
 using namespace mlir::omp;
 
+static ArrayAttr makeArrayAttr(MLIRContext *context,
+                               llvm::ArrayRef<Attribute> attrs) {
+  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
+}
+
 namespace {
 struct MemRefPointerLikeModel
     : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -1161,6 +1166,17 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TargetDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetDataOp::build(OpBuilder &builder, OperationState &state,
+                         const TargetDataClauseOps &clauses) {
+  TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+                      clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
+                      clauses.mapVars);
+}
+
 LogicalResult TargetDataOp::verify() {
   if (getMapOperands().empty() && getUseDevicePtr().empty() &&
       getUseDeviceAddr().empty()) {
@@ -1170,6 +1186,20 @@ LogicalResult TargetDataOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
 
+//===----------------------------------------------------------------------===//
+// TargetEnterDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetEnterDataOp::build(
+    OpBuilder &builder, OperationState &state,
+    const TargetEnterExitUpdateDataClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+                           makeArrayAttr(ctx, clauses.dependTypeAttrs),
+                           clauses.dependVars, clauses.nowaitAttr,
+                           clauses.mapVars);
+}
+
 LogicalResult TargetEnterDataOp::verify() {
   LogicalResult verifyDependVars =
       verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1177,6 +1207,20 @@ LogicalResult TargetEnterDataOp::verify() {
                                   : verifyMapClause(*this, getMapOperands());
 }
 
+//===----------------------------------------------------------------------===//
+// TargetExitDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetExitDataOp::build(
+    OpBuilder &builder, OperationState &state,
+    const TargetEnterExitUpdateDataClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+                          makeArrayAttr(ctx, clauses.dependTypeAttrs),
+                          clauses.dependVars, clauses.nowaitAttr,
+                          clauses.mapVars);
+}
+
 LogicalResult TargetExitDataOp::verify() {
   LogicalResult verifyDependVars =
       verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1184,6 +1228,19 @@ LogicalResult TargetExitDataOp::verify() {
                                   : verifyMapClause(*this, getMapOperands());
 }
 
+//===----------------------------------------------------------------------===//
+// TargetUpdateOp
+//===----------------------------------------------------------------------===//
+
+void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
+                           const TargetEnterExitUpdateDataClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+                        makeArrayAttr(ctx, clauses.dependTypeAttrs),
+                        clauses.dependVars, clauses.nowaitAttr,
+                        clauses.mapVars);
+}
+
 LogicalResult TargetUpdateOp::verify() {
   LogicalResult verifyDependVars =
       verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1191,6 +1248,22 @@ LogicalResult TargetUpdateOp::verify() {
                                   : verifyMapClause(*this, getMapOperands());
 }
 
+//===----------------------------------------------------------------------===//
+// TargetOp
+//===----------------------------------------------------------------------===//
+
+void TargetOp::build(OpBuilder &builder, OperationState &state,
+                     const TargetClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
+  // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
+  // reductionByRefAttr, reductionDeclSymbols.
+  TargetOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+                  clauses.threadLimitVar,
+                  makeArrayAttr(ctx, clauses.dependTypeAttrs),
+                  clauses.dependVars, clauses.nowaitAttr, clauses.mapVars);
+}
+
 LogicalResult TargetOp::verify() {
   LogicalResult verifyDependVars =
       verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1213,6 +1286,17 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   state.addAttributes(attributes);
 }
 
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+                       const ParallelClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  ParallelOp::build(
+      builder, state, clauses.ifVar, clauses.numThreadsVar,
+      clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars,
+      makeArrayAttr(ctx, clauses.reductionDeclSymbols),
+      clauses.procBindKindAttr, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privatizers), clauses.reductionByRefAttr);
+}
+
 template <typename OpType>
 static LogicalResult verifyPrivateVarList(OpType &op) {
   auto privateVars = op.getPrivateVars();
@@ -1280,6 +1364,17 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
   return true;
 }
 
+void TeamsOp::build(OpBuilder &builder, OperationState &state,
+                    const TeamsClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+  TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
+                 clauses.numTeamsUpperVar, clauses.ifVar,
+                 clauses.threadLimitVar, clauses.allocateVars,
+                 clauses.allocatorVars, clauses.reductionVars,
+                 makeArrayAttr(ctx, clauses.reductionDeclSymbols));
+}
+
 LogicalResult TeamsOp::verify() {
   // Check parent region
   // TODO If nested inside of a target region, also check that it does not
@@ -1312,9 +1407,19 @@ LogicalResult TeamsOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for SectionsOp
+// SectionsOp
 //===----------------------------------------------------------------------===//
 
+void SectionsOp::build(OpBuilder &builder, OperationState &state,
+                       const SectionsClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+  SectionsOp::build(builder, state, clauses.reductionVars,
+                    makeArrayAttr(ctx, clauses.reductionDeclSymbols),
+                    clauses.allocateVars, clauses.allocatorVars,
+                    clauses.nowaitAttr);
+}
+
 LogicalResult SectionsOp::verify() {
   if (getAllocateVars().size() != getAllocatorsVars().size())
     return emitError(
@@ -1334,6 +1439,20 @@ LogicalResult SectionsOp::verifyRegions() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SingleOp
+//===----------------------------------------------------------------------===//
+
+void SingleOp::build(OpBuilder &builder, OperationState &state,
+                     const SingleClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: privateVars, privatizers.
+  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+                  clauses.copyprivateVars,
+                  makeArrayAttr(ctx, clauses.copyprivateFuncs),
+                  clauses.nowaitAttr);
+}
+
 LogicalResult SingleOp::verify() {
   // Check for allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorsVars().size())
@@ -1481,9 +1600,21 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for Simd construct [2.9.3.1]
+// Simd construct [2.9.3.1]
 //===----------------------------------------------------------------------===//
 
+void SimdLoopOp::build(OpBuilder &builder, OperationState &state,
+                       const SimdLoopClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
+  // privatizers, reductionDeclSymbols.
+  SimdLoopOp::build(
+      builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
+      clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs),
+      clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr,
+      clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr);
+}
+
 LogicalResult SimdLoopOp::verify() {
   if (this->getLowerBound().empty()) {
     return emitOpError() << "empty lowerbound for simd loop operation";
@@ -1504,9 +1635,17 @@ LogicalResult SimdLoopOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for Distribute construct [2.9.4.1]
+// Distribute construct [2.9.4.1]
 //===----------------------------------------------------------------------===//
 
+void DistributeOp::build(OpBuilder &builder, OperationState &state,
+                         const DistributeClauseOps &clauses) {
+  // TODO Store clauses in op: privateVars, privatizers.
+  DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
+                      clauses.distScheduleChunkSizeVar, clauses.allocateVars,
+                      clauses.allocatorVars, clauses.orderAttr);
+}
+
 LogicalResult DistributeOp::verify() {
   if (this->getChunkSize() && !this->getDistScheduleStatic())
     return emitOpError() << "chunk size set without "
@@ -1607,6 +1746,19 @@ LogicalResult ReductionOp::verify() {
 //===----------------------------------------------------------------------===//
 // TaskOp
 //===----------------------------------------------------------------------===//
+
+void TaskOp::build(OpBuilder &builder, OperationState &state,
+                   const TaskClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: privateVars, privatizers.
+  TaskOp::build(
+      builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
+      clauses.mergeableAttr, clauses.inReductionVars,
+      makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
+      makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
+      clauses.allocateVars, clauses.allocatorVars);
+}
+
 LogicalResult TaskOp::verify() {
   LogicalResult verifyDependVars =
       verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1619,6 +1771,15 @@ LogicalResult TaskOp::verify() {
 //===----------------------------------------------------------------------===//
 // TaskgroupOp
 //===----------------------------------------------------------------------===//
+
+void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
+                        const TaskgroupClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  TaskgroupOp::build(builder, state, clauses.taskReductionVars,
+                     makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
+                     clauses.allocateVars, clauses.allocatorVars);
+}
+
 LogicalResult TaskgroupOp::verify() {
   return verifyReductionVarList(*this, getTaskReductions(),
                                 getTaskReductionVars());
@@ -1627,6 +1788,21 @@ LogicalResult TaskgroupOp::verify() {
 //===----------------------------------------------------------------------===//
 // TaskloopOp
 //===----------------------------------------------------------------------===//
+
+void TaskloopOp::build(OpBuilder &builder, OperationState &state,
+                       const TaskloopClauseOps &clauses) {
+  MLIRContext *ctx = builder.getContext();
+  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+  TaskloopOp::build(
+      builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
+      clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar,
+      clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars,
+      makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
+      makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
+      clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
+      clauses.numTasksVar, clauses.nogroupAttr);
+}
+
 SmallVector<Value> TaskloopOp::getAllReductionVars() {
   SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
                                        getInReductionVars().end());
@@ -1680,14 +1856,33 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
   state.addAttributes(attributes);
 }
 
+/// Builder overload taking the grouped worksharing-loop clause operands. It
+/// unpacks `clauses` and delegates to the positional-argument builder overload.
+void WsloopOp::build(OpBuilder &builder, OperationState &state,
+                     const WsloopClauseOps &clauses) {
+  // Convert the reduction symbol list up front so the forwarding call below
+  // only has to list already-prepared arguments.
+  auto reductionSyms =
+      makeArrayAttr(builder.getContext(), clauses.reductionDeclSymbols);
+
+  // TODO Store clauses in op: allocateVars, allocatorVars, privateVars,
+  // privatizers.
+  WsloopOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
+                  clauses.loopStepVar, clauses.linearVars,
+                  clauses.linearStepVars, clauses.reductionVars, reductionSyms,
+                  clauses.scheduleValAttr, clauses.scheduleChunkVar,
+                  clauses.scheduleModAttr, clauses.scheduleSimdAttr,
+                  clauses.nowaitAttr, clauses.reductionByRefAttr,
+                  clauses.orderedAttr, clauses.orderAttr,
+                  clauses.loopInclusiveAttr);
+}
+
 LogicalResult WsloopOp::verify() {
   return verifyReductionVarList(*this, getReductions(), getReductionVars());
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for critical construct (2.17.1)
+// Critical construct (2.17.1)
 //===----------------------------------------------------------------------===//
 
+/// Builder overload taking the grouped critical clause operands. It forwards
+/// the name and hint attributes to the positional-argument builder overload.
+void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
+                              const CriticalClauseOps &clauses) {
+  const auto &name = clauses.nameAttr;
+  const auto &hint = clauses.hintAttr;
+  CriticalDeclareOp::build(builder, state, name, hint);
+}
+
 LogicalResult CriticalDeclareOp::verify() {
   return verifySynchronizationHint(*this, getHintVal());
 }
@@ -1707,9 +1902,15 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 //===----------------------------------------------------------------------===//
-// Verifier for ordered construct
+// Ordered construct
 //===----------------------------------------------------------------------===//
 
+/// Builder overload taking the grouped ordered clause operands. It forwards
+/// the doacross fields of `clauses` to the positional-argument builder
+/// overload.
+void OrderedOp::build(OpBuilder &builder, OperationState &state,
+                      const OrderedOpClauseOps &clauses) {
+  const auto &dependType = clauses.doacrossDependTypeAttr;
+  const auto &numLoops = clauses.doacrossNumLoopsAttr;
+  const auto &vectorVars = clauses.doacrossVectorVars;
+  OrderedOp::build(builder, state, dependType, numLoops, vectorVars);
+}
+
 LogicalResult OrderedOp::verify() {
   auto container = (*this)->getParentOfType<WsloopOp>();
   if (!container || !container.getOrderedValAttr() ||
@@ -1726,6 +1927,11 @@ LogicalResult OrderedOp::verify() {
   return success();
 }
 
+/// Builder overload taking the grouped ordered-region clause operands. It
+/// forwards the single attribute in `clauses` to the positional-argument
+/// builder overload.
+void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
+                            const OrderedRegionClauseOps &clauses) {
+  const auto &parLevelSimd = clauses.parLevelSimdAttr;
+  OrderedRegionOp::build(builder, state, parLevelSimd);
+}
+
 LogicalResult OrderedRegionOp::verify() {
   // TODO: The code generation for ordered simd directive is not supported yet.
   if (getSimd())
@@ -1742,6 +1948,16 @@ LogicalResult OrderedRegionOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TaskwaitOp
+//===----------------------------------------------------------------------===//
+
+/// Builder overload taking the grouped taskwait clause operands. None of the
+/// fields in `clauses` are forwarded yet; the two-argument builder overload is
+/// called instead.
+void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
+                       const TaskwaitClauseOps &clauses) {
+  // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr.
+  // Until then `clauses` is intentionally not read.
+  (void)clauses;
+  TaskwaitOp::build(builder, state);
+}
+
 //===----------------------------------------------------------------------===//
 // Verifier for AtomicReadOp
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list