[flang-commits] [flang] [mlir] [MLIR][OpenMP] Automate operand structure definition (PR #99508)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Wed Sep 11 03:45:36 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/99508

>From 0cbfcd4876d881524b79f4cf27f3c86852eed14f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 17 Jul 2024 13:26:09 +0100
Subject: [PATCH] [MLIR][OpenMP] Automate operand structure definition

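This patch adds the "gen-openmp-clause-ops" `mlir-tblgen` generator, which
automatically produces the clause operand structure definitions that were
previously hand-written in OpenMPClauseOperands.h, using the information
contained in OpenMPOps.td and OpenMPClauses.td.

Because structure and field names are now derived from the tablegen
definitions, some of them change: `LoopRelatedOps` becomes
`LoopRelatedClauseOps` and `IfClauseOps::ifVar` becomes `ifExpr`. Users in
flang lowering and the OpenMP dialect are updated accordingly.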

The original header is kept so that it can still define similar structures
that are not directly associated with any single `OpenMP_Clause` or
`OpenMP_Op` tablegen definition.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |   6 +-
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |   2 +-
 .../mlir/Dialect/OpenMP/CMakeLists.txt        |   1 +
 .../Dialect/OpenMP/OpenMPClauseOperands.h     | 292 +-----------------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  34 +-
 mlir/test/mlir-tblgen/openmp-clause-ops.td    |  86 ++++++
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           | 215 ++++++++++++-
 7 files changed, 319 insertions(+), 317 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/openmp-clause-ops.td

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 3f54234b176e3f..f336d213cc8620 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -181,7 +181,7 @@ static void addUseDeviceClause(
 
 static void convertLoopBounds(lower::AbstractConverter &converter,
                               mlir::Location loc,
-                              mlir::omp::LoopRelatedOps &result,
+                              mlir::omp::LoopRelatedClauseOps &result,
                               std::size_t loopVarTypeSize) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   // The types of lower bound, upper bound, and step are converted into the
@@ -203,7 +203,7 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
 
 bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
-    mlir::omp::LoopRelatedOps &result,
+    mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
   bool found = false;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -855,7 +855,7 @@ bool ClauseProcessor::processIf(
     // Assume that, at most, a single 'if' clause will be applicable to the
     // given directive.
     if (operand) {
-      result.ifVar = operand;
+      result.ifExpr = operand;
       found = true;
     }
   });
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index f6b319c726a2d1..8d02d368f4ee04 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -55,7 +55,7 @@ class ClauseProcessor {
   // 'Unique' clauses: They can appear at most once in the clause list.
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
-                  mlir::omp::LoopRelatedOps &result,
+                  mlir::omp::LoopRelatedClauseOps &result,
                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
   bool processDevice(lower::StatementContext &stmtCtx,
                      mlir::omp::DeviceClauseOps &result) const;
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index dd349d1392e7bf..a65c6b1d3c96bc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -17,6 +17,7 @@ mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
 mlir_tablegen(OpenMPOpsDialect.cpp.inc -gen-dialect-defs -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
 mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
+mlir_tablegen(OpenMPClauseOps.h.inc -gen-openmp-clause-ops)
 mlir_tablegen(OpenMPOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=omp)
 mlir_tablegen(OpenMPOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=omp)
 mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 38e4d8f245e4fa..1247a871f93c6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -23,303 +23,31 @@
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
 
+#include "mlir/Dialect/OpenMP/OpenMPClauseOps.h.inc"
+
 namespace mlir {
 namespace omp {
 
 //===----------------------------------------------------------------------===//
-// Mixin structures defining MLIR operands associated with each OpenMP clause.
+// Extra clause operand structures.
 //===----------------------------------------------------------------------===//
 
-struct AlignedClauseOps {
-  llvm::SmallVector<Value> alignedVars;
-  llvm::SmallVector<Attribute> alignments;
-};
-
-struct AllocateClauseOps {
-  llvm::SmallVector<Value> allocateVars, allocatorVars;
-};
-
-struct CancelDirectiveNameClauseOps {
-  ClauseCancellationConstructTypeAttr cancelDirective;
-};
-
-struct CopyprivateClauseOps {
-  llvm::SmallVector<Value> copyprivateVars;
-  llvm::SmallVector<Attribute> copyprivateSyms;
-};
-
-struct CriticalNameClauseOps {
-  /// This field has a generic name because it's mirroring the `sym_name`
-  /// argument of the `OpenMP_CriticalNameClause` tablegen definition. That one
-  /// can't be renamed to anything more specific because the `sym_name` name is
-  /// a requirement of the `Symbol` MLIR trait associated with that clause.
-  StringAttr symName;
-};
-
-struct DependClauseOps {
-  llvm::SmallVector<Attribute> dependKinds;
-  llvm::SmallVector<Value> dependVars;
-};
-
-struct DeviceClauseOps {
-  Value device;
-};
-
 struct DeviceTypeClauseOps {
-  // The default capture type.
+  /// The default capture type.
   DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
 };
 
-struct DistScheduleClauseOps {
-  UnitAttr distScheduleStatic;
-  Value distScheduleChunkSize;
-};
-
-struct DoacrossClauseOps {
-  ClauseDependAttr doacrossDependType;
-  IntegerAttr doacrossNumLoops;
-  llvm::SmallVector<Value> doacrossDependVars;
-};
-
-struct FilterClauseOps {
-  Value filteredThreadId;
-};
-
-struct FinalClauseOps {
-  Value final;
-};
-
-struct GrainsizeClauseOps {
-  Value grainsize;
-};
-
-struct HasDeviceAddrClauseOps {
-  llvm::SmallVector<Value> hasDeviceAddrVars;
-};
-
-struct HintClauseOps {
-  IntegerAttr hint;
-};
-
-struct IfClauseOps {
-  Value ifVar;
-};
-
-struct InReductionClauseOps {
-  llvm::SmallVector<Value> inReductionVars;
-  llvm::SmallVector<bool> inReductionByref;
-  llvm::SmallVector<Attribute> inReductionSyms;
-};
-
-struct IsDevicePtrClauseOps {
-  llvm::SmallVector<Value> isDevicePtrVars;
-};
-
-struct LinearClauseOps {
-  llvm::SmallVector<Value> linearVars, linearStepVars;
-};
-
-struct LoopRelatedOps {
-  llvm::SmallVector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
-  UnitAttr loopInclusive;
-};
-
-struct MapClauseOps {
-  llvm::SmallVector<Value> mapVars;
-};
-
-struct MergeableClauseOps {
-  UnitAttr mergeable;
-};
-
-struct NogroupClauseOps {
-  UnitAttr nogroup;
-};
-
-struct NontemporalClauseOps {
-  llvm::SmallVector<Value> nontemporalVars;
-};
-
-struct NowaitClauseOps {
-  UnitAttr nowait;
-};
-
-struct NumTasksClauseOps {
-  Value numTasks;
-};
-
-struct NumTeamsClauseOps {
-  Value numTeamsLower, numTeamsUpper;
-};
-
-struct NumThreadsClauseOps {
-  Value numThreads;
-};
-
-struct OrderClauseOps {
-  ClauseOrderKindAttr order;
-  OrderModifierAttr orderMod;
-};
-
-struct OrderedClauseOps {
-  IntegerAttr ordered;
-};
-
-struct ParallelizationLevelClauseOps {
-  UnitAttr parLevelSimd;
-};
-
-struct PriorityClauseOps {
-  Value priority;
-};
-
-struct PrivateClauseOps {
-  // SSA values that correspond to "original" values being privatized.
-  // They refer to the SSA value outside the OpenMP region from which a clone is
-  // created inside the region.
-  llvm::SmallVector<Value> privateVars;
-  // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
-  // ops).
-  llvm::SmallVector<Attribute> privateSyms;
-};
-
-struct ProcBindClauseOps {
-  ClauseProcBindKindAttr procBindKind;
-};
-
-struct ReductionClauseOps {
-  llvm::SmallVector<Value> reductionVars;
-  llvm::SmallVector<bool> reductionByref;
-  llvm::SmallVector<Attribute> reductionSyms;
-};
-
-struct SafelenClauseOps {
-  IntegerAttr safelen;
-};
-
-struct ScheduleClauseOps {
-  ClauseScheduleKindAttr scheduleKind;
-  Value scheduleChunk;
-  ScheduleModifierAttr scheduleMod;
-  UnitAttr scheduleSimd;
-};
-
-struct SimdlenClauseOps {
-  IntegerAttr simdlen;
-};
-
-struct TaskReductionClauseOps {
-  llvm::SmallVector<Value> taskReductionVars;
-  llvm::SmallVector<bool> taskReductionByref;
-  llvm::SmallVector<Attribute> taskReductionSyms;
-};
-
-struct ThreadLimitClauseOps {
-  Value threadLimit;
-};
-
-struct UntiedClauseOps {
-  UnitAttr untied;
-};
-
-struct UseDeviceAddrClauseOps {
-  llvm::SmallVector<Value> useDeviceAddrVars;
-};
-
-struct UseDevicePtrClauseOps {
-  llvm::SmallVector<Value> useDevicePtrVars;
-};
-
 //===----------------------------------------------------------------------===//
-// Structures defining clause operands associated with each OpenMP leaf
-// construct.
-//
-// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+// Extra operation operand structures.
 //===----------------------------------------------------------------------===//
 
-namespace detail {
-template <typename... Mixins>
-struct Clauses : public Mixins... {};
-} // namespace detail
-
-using CancelOperands =
-    detail::Clauses<CancelDirectiveNameClauseOps, IfClauseOps>;
-
-using CancellationPointOperands = detail::Clauses<CancelDirectiveNameClauseOps>;
-
-using CriticalDeclareOperands =
-    detail::Clauses<CriticalNameClauseOps, HintClauseOps>;
-
-// TODO `indirect` clause.
+// TODO: Add `indirect` clause.
 using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
 
-using DistributeOperands =
-    detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
-                    PrivateClauseOps>;
-
-using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
-
-using MaskedOperands = detail::Clauses<FilterClauseOps>;
-
-using OrderedOperands = detail::Clauses<DoacrossClauseOps>;
-
-using OrderedRegionOperands = detail::Clauses<ParallelizationLevelClauseOps>;
-
-using ParallelOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
-                    PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
-
-using SectionsOperands = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
-                                         PrivateClauseOps, ReductionClauseOps>;
-
-using SimdOperands =
-    detail::Clauses<AlignedClauseOps, IfClauseOps, LinearClauseOps,
-                    NontemporalClauseOps, OrderClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, SafelenClauseOps, SimdlenClauseOps>;
-
-using SingleOperands = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
-                                       NowaitClauseOps, PrivateClauseOps>;
-
-// TODO `defaultmap`, `uses_allocators` clauses.
-using TargetOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
-                    HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
-                    IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
-                    PrivateClauseOps, ThreadLimitClauseOps>;
-
-using TargetDataOperands =
-    detail::Clauses<DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    UseDeviceAddrClauseOps, UseDevicePtrClauseOps>;
-
-using TargetEnterExitUpdateDataOperands =
-    detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    NowaitClauseOps>;
-
-// TODO `affinity`, `detach` clauses.
-using TaskOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
-
-using TaskgroupOperands =
-    detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
-
-using TaskloopOperands =
-    detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
-
-using TaskwaitOperands = detail::Clauses<DependClauseOps, NowaitClauseOps>;
-
-using TeamsOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
-
-using WsloopOperands =
-    detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
-                    OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, ScheduleClauseOps>;
+/// omp.target_enter_data, omp.target_exit_data and omp.target_update take the
+/// same clauses, so we give the structure to be shared by all of them a
+/// representative name.
+using TargetEnterExitUpdateDataOperands = TargetEnterDataOperands;
 
 } // namespace omp
 } // namespace mlir
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1a9b87f0d68c9d..e4ed58f26016a5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1370,7 +1370,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
 
 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
                          const TargetDataOperands &clauses) {
-  TargetDataOp::build(builder, state, clauses.device, clauses.ifVar,
+  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
                       clauses.mapVars, clauses.useDeviceAddrVars,
                       clauses.useDevicePtrVars);
 }
@@ -1395,7 +1395,7 @@ void TargetEnterDataOp::build(
   MLIRContext *ctx = builder.getContext();
   TargetEnterDataOp::build(builder, state,
                            makeArrayAttr(ctx, clauses.dependKinds),
-                           clauses.dependVars, clauses.device, clauses.ifVar,
+                           clauses.dependVars, clauses.device, clauses.ifExpr,
                            clauses.mapVars, clauses.nowait);
 }
 
@@ -1415,7 +1415,7 @@ void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   TargetExitDataOp::build(builder, state,
                           makeArrayAttr(ctx, clauses.dependKinds),
-                          clauses.dependVars, clauses.device, clauses.ifVar,
+                          clauses.dependVars, clauses.device, clauses.ifExpr,
                           clauses.mapVars, clauses.nowait);
 }
 
@@ -1434,7 +1434,7 @@ void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                            const TargetEnterExitUpdateDataOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
-                        clauses.dependVars, clauses.device, clauses.ifVar,
+                        clauses.dependVars, clauses.device, clauses.ifExpr,
                         clauses.mapVars, clauses.nowait);
 }
 
@@ -1456,7 +1456,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                   makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifVar,
+                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1488,9 +1488,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifVar, clauses.numThreads, clauses.privateVars,
+                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
                     clauses.procBindKind, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
@@ -1588,13 +1587,12 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms.
-  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                 clauses.ifVar, clauses.numTeamsLower, clauses.numTeamsUpper,
-                 /*private_vars=*/{},
-                 /*private_syms=*/nullptr, clauses.reductionVars,
-                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                 makeArrayAttr(ctx, clauses.reductionSyms),
-                 clauses.threadLimit);
+  TeamsOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+      /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
 }
 
 LogicalResult TeamsOp::verify() {
@@ -1814,7 +1812,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
   // privateSyms, reductionVars, reductionByref, reductionSyms.
   SimdOp::build(builder, state, clauses.alignedVars,
-                makeArrayAttr(ctx, clauses.alignments), clauses.ifVar,
+                makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 /*private_vars=*/{}, /*private_syms=*/nullptr,
@@ -1996,7 +1994,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                clauses.final, clauses.ifVar, clauses.inReductionVars,
+                clauses.final, clauses.ifExpr, clauses.inReductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/{}, /*private_syms=*/nullptr,
@@ -2042,7 +2040,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskloopOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.final, clauses.grainsize, clauses.ifVar, clauses.inReductionVars,
+      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
@@ -2424,7 +2422,7 @@ LogicalResult AtomicCaptureOp::verifyRegions() {
 
 void CancelOp::build(OpBuilder &builder, OperationState &state,
                      const CancelOperands &clauses) {
-  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifVar);
+  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
 }
 
 LogicalResult CancelOp::verify() {
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
new file mode 100644
index 00000000000000..cee3f2a693bf8c
--- /dev/null
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -0,0 +1,86 @@
+// Tablegen tests for the automatic generation of OpenMP clause operand
+// structure definitions.
+
+// Run tablegen to generate OmpCommon.td in a temporary directory first.
+// RUN: mkdir -p %t/mlir/Dialect/OpenMP
+// RUN: mlir-tblgen --gen-directive-decl --directives-dialect=OpenMP \
+// RUN:   %S/../../../llvm/include/llvm/Frontend/OpenMP/OMP.td \
+// RUN:   -I %S/../../../llvm/include > %t/mlir/Dialect/OpenMP/OmpCommon.td
+
+// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s 2>&1 | FileCheck %s
+
+include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
+
+
+def OpenMP_MyFirstClause : OpenMP_Clause<
+    /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+    /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+    /*skipExtraClassDeclaration=*/false> {
+  let arguments = (ins
+    // Simple attributes
+    I32Attr:$int_attr,
+    TypeAttr:$type_attr,
+    DeclareTargetAttr:$omp_attr,
+
+    // Array attributes
+    F32ArrayAttr:$float_array_attr,
+    StrArrayAttr:$str_array_attr,
+    AnyIntElementsAttr:$anyint_elems_attr,
+    RankedF32ElementsAttr<[3, 4, 5]>:$float_nd_elems_attr,
+
+    // Optional attributes
+    OptionalAttr<BoolAttr>:$opt_bool_attr,
+    OptionalAttr<I64ArrayAttr>:$opt_int_array_attr,
+    OptionalAttr<DenseI8ArrayAttr>:$opt_int_elems_attr,
+
+    // Multi-level composition
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr,
+
+    // ElementsAttrBase-related edge cases.
+    // CHECK: warning: could not infer array-like attribute element type for argument 'elements_attr', will use bare `storageType`
+    ElementsAttr:$elements_attr,
+    // CHECK: warning: could not infer array-like attribute element type for argument 'string_elements_attr', will use bare `storageType`
+    StringElementsAttr:$string_elements_attr
+  );
+}
+// CHECK:      struct MyFirstClauseOps {
+// CHECK-NEXT:   ::mlir::IntegerAttr intAttr;
+// CHECK-NEXT:   ::mlir::TypeAttr typeAttr;
+// CHECK-NEXT:   ::mlir::omp::DeclareTargetAttr ompAttr;
+
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> floatArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> strArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::llvm::APInt> anyintElemsAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::llvm::APFloat> floatNdElemsAttr;
+
+// CHECK-NEXT:   ::mlir::BoolAttr optBoolAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> optIntArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<int8_t> optIntElemsAttr;
+
+// CHECK-NEXT:   ::mlir::IntegerAttr complexOptIntAttr;
+
+// CHECK-NEXT:   ::mlir::ElementsAttr elementsAttr;
+// CHECK-NEXT:   ::mlir::DenseElementsAttr stringElementsAttr;
+// CHECK-NEXT: }
+
+def OpenMP_MySecondClause : OpenMP_Clause<
+    /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+    /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+    /*skipExtraClassDeclaration=*/false> {
+  let arguments = (ins
+    I32:$int_val,
+    Optional<AnyType>:$opt_any_val,
+    Variadic<Index>:$variadic_index_val
+  );
+}
+// CHECK:      struct MySecondClauseOps {
+// CHECK-NEXT:   ::mlir::Value intVal;
+// CHECK-NEXT:   ::mlir::Value optAnyVal;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Value> variadicIndexVal;
+// CHECK-NEXT: }
+
+def OpenMP_MyFirstOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause]>;
+// CHECK: using MyFirstOperands = detail::Clauses<MyFirstClauseOps>;
+
+def OpenMP_MySecondOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause, OpenMP_MySecondClause]>;
+// CHECK: using MySecondOperands = detail::Clauses<MyFirstClauseOps, MySecondClauseOps>;
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index 15458212637888..f546e1b1b6691c 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -12,11 +12,54 @@
 
 #include "mlir/TableGen/GenInfo.h"
 
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatAdapters.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace llvm;
 
+/// The code block defining the base mixin class for combining clause operand
+/// structures.
+static const char *const baseMixinClass = R"(
+namespace detail {
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+)";
+
+/// The code block defining operation argument structures.
+static const char *const operationArgStruct = R"(
+using {0}Operands = detail::Clauses<{1}>;
+)";
+
+/// Remove multiple optional prefixes and suffixes from \c str.
+///
+/// Prefixes and suffixes are attempted to be removed once in the order they
+/// appear in the \c prefixes and \c suffixes arguments. All prefixes are
+/// processed before suffixes are. This means it will behave as shown in the
+/// following example:
+///   - str: "PrePreNameSuf1Suf2"
+///   - prefixes: ["Pre"]
+///   - suffixes: ["Suf1", "Suf2"]
+///   - return: "PreNameSuf1"
+static StringRef stripPrefixAndSuffix(StringRef str,
+                                      llvm::ArrayRef<StringRef> prefixes,
+                                      llvm::ArrayRef<StringRef> suffixes) {
+  for (StringRef prefix : prefixes)
+    if (str.starts_with(prefix))
+      str = str.drop_front(prefix.size());
+
+  for (StringRef suffix : suffixes)
+    if (str.ends_with(suffix))
+      str = str.drop_back(suffix.size());
+
+  return str;
+}
+
 /// Obtain the name of the OpenMP clause a given record inheriting
 /// `OpenMP_Clause` refers to.
 ///
@@ -53,19 +96,8 @@ static StringRef extractOmpClauseName(const Record *clause) {
   assert(!clauseClassName.empty() && "clause name must be found");
 
   // Keep only the OpenMP clause name itself for reporting purposes.
-  StringRef prefix = "OpenMP_";
-  StringRef suffixes[] = {"Skip", "Clause"};
-
-  if (clauseClassName.starts_with(prefix))
-    clauseClassName = clauseClassName.substr(prefix.size());
-
-  for (StringRef suffix : suffixes) {
-    if (clauseClassName.ends_with(suffix))
-      clauseClassName =
-          clauseClassName.substr(0, clauseClassName.size() - suffix.size());
-  }
-
-  return clauseClassName;
+  return stripPrefixAndSuffix(clauseClassName, /*prefixes=*/{"OpenMP_"},
+                              /*suffixes=*/{"Skip", "Clause"});
 }
 
 /// Check that the given argument, identified by its name and initialization
@@ -148,6 +180,139 @@ static void verifyClause(const Record *op, const Record *clause) {
             "or explicitly skipping this field.");
 }
 
+/// Translate the type of an OpenMP clause's argument to its corresponding
+/// representation for clause operand structures.
+///
+/// All kinds of values are represented as `mlir::Value` fields, whereas
+/// attributes are represented based on their `storageType`.
+///
+/// \param[in] loc,name Location and name of the argument, used in warnings.
+/// \param[in] init The `DefInit` object representing the argument.
+/// \param[in,out] nest Number of levels of array nesting associated with
+///                     the type. Must be initially set to 0.
+/// \param[in,out] rank Rank (number of dimensions, if an array type) of
+///                     the base type. Must be initially set to 1.
+///
+/// \return the name of the base type to represent elements of the argument
+///         type.
+static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name,
+                                       Init *init, int &nest, int &rank) {
+  Record *def = cast<DefInit>(init)->getDef();
+
+  llvm::StringSet superClasses;
+  for (auto [sc, _] : def->getSuperClasses())
+    superClasses.insert(sc->getNameInitAsString());
+
+  // Handle wrapper-style superclasses.
+  if (superClasses.contains("OptionalAttr"))
+    return translateArgumentType(
+        loc, name, def->getValue("baseAttr")->getValue(), nest, rank);
+
+  if (superClasses.contains("TypedArrayAttrBase"))
+    return translateArgumentType(
+        loc, name, def->getValue("elementAttr")->getValue(), ++nest, rank);
+
+  // Handle ElementsAttrBase superclasses.
+  if (superClasses.contains("ElementsAttrBase")) {
+    // TODO: Obtain the rank from ranked types.
+    ++nest;
+
+    if (superClasses.contains("IntElementsAttrBase"))
+      return "::llvm::APInt";
+    if (superClasses.contains("FloatElementsAttr") ||
+        superClasses.contains("RankedFloatElementsAttr"))
+      return "::llvm::APFloat";
+    if (superClasses.contains("DenseArrayAttrBase"))
+      return stripPrefixAndSuffix(def->getValueAsString("returnType"),
+                                  {"::llvm::ArrayRef<"}, {">"});
+
+    // Decrease the nesting depth in the case where the base type cannot be
+    // inferred, so that the bare storageType is used instead of a vector.
+    --nest;
+    PrintWarning(
+        loc,
+        "could not infer array-like attribute element type for argument '" +
+            name->getAsUnquotedString() + "', will use bare `storageType`");
+  }
+
+  // Handle simple attribute and value types.
+  bool isAttr = superClasses.contains("Attr");
+  bool isValue = superClasses.contains("TypeConstraint");
+  if (superClasses.contains("Variadic"))
+    ++nest;
+
+  if (isValue) {
+    assert(!isAttr &&
+           "argument can't be simultaneously a value and an attribute");
+    return "::mlir::Value";
+  }
+
+  assert(isAttr && "argument must be an attribute if it's not a value");
+  return nest > 0 ? "::mlir::Attribute"
+                  : def->getValueAsString("storageType").trim();
+}
+
+/// Generate the structure that represents the arguments of the given \c clause
+/// record of type \c OpenMP_Clause.
+///
+/// It will contain a field for each argument, using the same name translated to
+/// camel case and the corresponding base type as returned by
+/// translateArgumentType() optionally wrapped in one or more llvm::SmallVector.
+///
+/// For multi-rank types, an additional field holding a tuple with the size of
+/// each dimension would also be created. Rank inference is not yet supported,
+/// so this field is currently never emitted.
+static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
+  if (clause->isAnonymous())
+    return;
+
+  StringRef clauseName = extractOmpClauseName(clause);
+  os << "struct " << clauseName << "ClauseOps {\n";
+
+  DagInit *arguments = clause->getValueAsDag("arguments");
+  for (auto [name, arg] :
+       zip_equal(arguments->getArgNames(), arguments->getArgs())) {
+    int nest = 0, rank = 1;
+    StringRef baseType =
+        translateArgumentType(clause->getLoc(), name, arg, nest, rank);
+    std::string fieldName =
+        convertToCamelFromSnakeCase(name->getAsUnquotedString(),
+                                    /*capitalizeFirst=*/false);
+
+    os << formatv("  {0}{1}{2} {3};\n",
+                  fmt_repeat("::llvm::SmallVector<", nest), baseType,
+                  fmt_repeat(">", nest), fieldName);
+
+    if (rank > 1) {
+      assert(nest >= 1 && "must be nested if it's a ranked type");
+      os << formatv("  {0}::std::tuple<{1}int>{2} {3}Dims;\n",
+                    fmt_repeat("::llvm::SmallVector<", nest - 1),
+                    fmt_repeat("int, ", rank - 1), fmt_repeat(">", nest - 1),
+                    fieldName);
+    }
+  }
+
+  os << "};\n";
+}
+
+/// Generate the structure that represents the clause-related arguments of the
+/// given \c op record of type \c OpenMP_Op.
+///
+/// This structure will be defined in terms of the clause operand structures
+/// associated to the clauses of the operation.
+static void genOperandsDef(const Record *op, raw_ostream &os) {
+  if (op->isAnonymous())
+    return;
+
+  SmallVector<std::string> clauseNames;
+  for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+    clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
+
+  StringRef opName = stripPrefixAndSuffix(
+      op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
+  os << formatv(operationArgStruct, opName, join(clauseNames, ", "));
+}
+
 /// Verify that all properties of `OpenMP_Clause`s of records deriving from
 /// `OpenMP_Op`s have been inherited by the latter.
 static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
@@ -159,8 +324,32 @@ static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
   return false;
 }
 
+/// Generate structures to represent clause-related operands, based on existing
+/// `OpenMP_Clause` definitions and aggregate them into operation-specific
+/// structures according to the `clauses` argument of each definition deriving
+/// from `OpenMP_Op`.
+static bool genClauseOps(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");
+  for (const Record *clause :
+       recordKeeper.getAllDerivedDefinitions("OpenMP_Clause"))
+    genClauseOpsStruct(clause, os);
+
+  // Produce base mixin class.
+  os << baseMixinClass;
+
+  for (const Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op"))
+    genOperandsDef(op, os);
+
+  return false;
+}
+
 // Registers the generator to mlir-tblgen.
 static mlir::GenRegistration
     verifyOpenmpOps("verify-openmp-ops",
                     "Verify OpenMP operations (produce no output file)",
                     verifyDecls);
+
+static mlir::GenRegistration
+    genOpenmpClauseOps("gen-openmp-clause-ops",
+                       "Generate OpenMP clause operand structures",
+                       genClauseOps);



More information about the flang-commits mailing list