[flang-commits] [flang] [mlir] [MLIR][OpenMP] Automate operand structure definition (PR #99508)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Tue Aug 27 09:21:24 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/99508

>From 8948a6bd0f0ce8438fcf5b2b1e3c4d538d2fd782 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 17 Jul 2024 13:26:09 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Automate operand structure definition

This patch adds the "gen-openmp-clause-ops" `mlir-tblgen` generator, which
automatically produces the structure definitions previously hand-written in
OpenMPClauseOperands.h from the information contained in OpenMPOps.td and
OpenMPClauses.td.

The changes to the `ElementsAttrBase` common tablegen class and to some of its
subclasses add finer-grained information about the rank and element type of
these attributes. This information is needed to generate the C++ types that
represent such attributes within the produced operand structures.

The original header is kept to allow defining similar structures that are not
directly tied to any single `OpenMP_Clause` or `OpenMP_Op` tablegen
definition.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |   6 +-
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |   2 +-
 .../mlir/Dialect/OpenMP/CMakeLists.txt        |   1 +
 .../Dialect/OpenMP/OpenMPClauseOperands.h     | 292 +-----------------
 mlir/include/mlir/IR/CommonAttrConstraints.td |  18 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  34 +-
 mlir/test/mlir-tblgen/openmp-clause-ops.td    |  78 +++++
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           | 184 ++++++++++-
 8 files changed, 297 insertions(+), 318 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/openmp-clause-ops.td

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 310c0b0b5fb636..4d510ebe7df589 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -181,7 +181,7 @@ static void addUseDeviceClause(
 
 static void convertLoopBounds(lower::AbstractConverter &converter,
                               mlir::Location loc,
-                              mlir::omp::LoopRelatedOps &result,
+                              mlir::omp::LoopRelatedClauseOps &result,
                               std::size_t loopVarTypeSize) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   // The types of lower bound, upper bound, and step are converted into the
@@ -203,7 +203,7 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
 
 bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
-    mlir::omp::LoopRelatedOps &result,
+    mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
   bool found = false;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -875,7 +875,7 @@ bool ClauseProcessor::processIf(
     // Assume that, at most, a single 'if' clause will be applicable to the
     // given directive.
     if (operand) {
-      result.ifVar = operand;
+      result.ifExpr = operand;
       found = true;
     }
   });
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 2c4b3857fda9f3..c870680b61756a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -55,7 +55,7 @@ class ClauseProcessor {
   // 'Unique' clauses: They can appear at most once in the clause list.
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
-                  mlir::omp::LoopRelatedOps &result,
+                  mlir::omp::LoopRelatedClauseOps &result,
                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
   bool processDefault() const;
   bool processDevice(lower::StatementContext &stmtCtx,
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index dd349d1392e7bf..a65c6b1d3c96bc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -17,6 +17,7 @@ mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
 mlir_tablegen(OpenMPOpsDialect.cpp.inc -gen-dialect-defs -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
 mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
+mlir_tablegen(OpenMPClauseOps.h.inc -gen-openmp-clause-ops)
 mlir_tablegen(OpenMPOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=omp)
 mlir_tablegen(OpenMPOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=omp)
 mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 38e4d8f245e4fa..1247a871f93c6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -23,303 +23,31 @@
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
 
+#include "mlir/Dialect/OpenMP/OpenMPClauseOps.h.inc"
+
 namespace mlir {
 namespace omp {
 
 //===----------------------------------------------------------------------===//
-// Mixin structures defining MLIR operands associated with each OpenMP clause.
+// Extra clause operand structures.
 //===----------------------------------------------------------------------===//
 
-struct AlignedClauseOps {
-  llvm::SmallVector<Value> alignedVars;
-  llvm::SmallVector<Attribute> alignments;
-};
-
-struct AllocateClauseOps {
-  llvm::SmallVector<Value> allocateVars, allocatorVars;
-};
-
-struct CancelDirectiveNameClauseOps {
-  ClauseCancellationConstructTypeAttr cancelDirective;
-};
-
-struct CopyprivateClauseOps {
-  llvm::SmallVector<Value> copyprivateVars;
-  llvm::SmallVector<Attribute> copyprivateSyms;
-};
-
-struct CriticalNameClauseOps {
-  /// This field has a generic name because it's mirroring the `sym_name`
-  /// argument of the `OpenMP_CriticalNameClause` tablegen definition. That one
-  /// can't be renamed to anything more specific because the `sym_name` name is
-  /// a requirement of the `Symbol` MLIR trait associated with that clause.
-  StringAttr symName;
-};
-
-struct DependClauseOps {
-  llvm::SmallVector<Attribute> dependKinds;
-  llvm::SmallVector<Value> dependVars;
-};
-
-struct DeviceClauseOps {
-  Value device;
-};
-
 struct DeviceTypeClauseOps {
-  // The default capture type.
+  /// Value of the `device_type` clause; defaults to `any`.
   DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
 };
 
-struct DistScheduleClauseOps {
-  UnitAttr distScheduleStatic;
-  Value distScheduleChunkSize;
-};
-
-struct DoacrossClauseOps {
-  ClauseDependAttr doacrossDependType;
-  IntegerAttr doacrossNumLoops;
-  llvm::SmallVector<Value> doacrossDependVars;
-};
-
-struct FilterClauseOps {
-  Value filteredThreadId;
-};
-
-struct FinalClauseOps {
-  Value final;
-};
-
-struct GrainsizeClauseOps {
-  Value grainsize;
-};
-
-struct HasDeviceAddrClauseOps {
-  llvm::SmallVector<Value> hasDeviceAddrVars;
-};
-
-struct HintClauseOps {
-  IntegerAttr hint;
-};
-
-struct IfClauseOps {
-  Value ifVar;
-};
-
-struct InReductionClauseOps {
-  llvm::SmallVector<Value> inReductionVars;
-  llvm::SmallVector<bool> inReductionByref;
-  llvm::SmallVector<Attribute> inReductionSyms;
-};
-
-struct IsDevicePtrClauseOps {
-  llvm::SmallVector<Value> isDevicePtrVars;
-};
-
-struct LinearClauseOps {
-  llvm::SmallVector<Value> linearVars, linearStepVars;
-};
-
-struct LoopRelatedOps {
-  llvm::SmallVector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
-  UnitAttr loopInclusive;
-};
-
-struct MapClauseOps {
-  llvm::SmallVector<Value> mapVars;
-};
-
-struct MergeableClauseOps {
-  UnitAttr mergeable;
-};
-
-struct NogroupClauseOps {
-  UnitAttr nogroup;
-};
-
-struct NontemporalClauseOps {
-  llvm::SmallVector<Value> nontemporalVars;
-};
-
-struct NowaitClauseOps {
-  UnitAttr nowait;
-};
-
-struct NumTasksClauseOps {
-  Value numTasks;
-};
-
-struct NumTeamsClauseOps {
-  Value numTeamsLower, numTeamsUpper;
-};
-
-struct NumThreadsClauseOps {
-  Value numThreads;
-};
-
-struct OrderClauseOps {
-  ClauseOrderKindAttr order;
-  OrderModifierAttr orderMod;
-};
-
-struct OrderedClauseOps {
-  IntegerAttr ordered;
-};
-
-struct ParallelizationLevelClauseOps {
-  UnitAttr parLevelSimd;
-};
-
-struct PriorityClauseOps {
-  Value priority;
-};
-
-struct PrivateClauseOps {
-  // SSA values that correspond to "original" values being privatized.
-  // They refer to the SSA value outside the OpenMP region from which a clone is
-  // created inside the region.
-  llvm::SmallVector<Value> privateVars;
-  // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
-  // ops).
-  llvm::SmallVector<Attribute> privateSyms;
-};
-
-struct ProcBindClauseOps {
-  ClauseProcBindKindAttr procBindKind;
-};
-
-struct ReductionClauseOps {
-  llvm::SmallVector<Value> reductionVars;
-  llvm::SmallVector<bool> reductionByref;
-  llvm::SmallVector<Attribute> reductionSyms;
-};
-
-struct SafelenClauseOps {
-  IntegerAttr safelen;
-};
-
-struct ScheduleClauseOps {
-  ClauseScheduleKindAttr scheduleKind;
-  Value scheduleChunk;
-  ScheduleModifierAttr scheduleMod;
-  UnitAttr scheduleSimd;
-};
-
-struct SimdlenClauseOps {
-  IntegerAttr simdlen;
-};
-
-struct TaskReductionClauseOps {
-  llvm::SmallVector<Value> taskReductionVars;
-  llvm::SmallVector<bool> taskReductionByref;
-  llvm::SmallVector<Attribute> taskReductionSyms;
-};
-
-struct ThreadLimitClauseOps {
-  Value threadLimit;
-};
-
-struct UntiedClauseOps {
-  UnitAttr untied;
-};
-
-struct UseDeviceAddrClauseOps {
-  llvm::SmallVector<Value> useDeviceAddrVars;
-};
-
-struct UseDevicePtrClauseOps {
-  llvm::SmallVector<Value> useDevicePtrVars;
-};
-
 //===----------------------------------------------------------------------===//
-// Structures defining clause operands associated with each OpenMP leaf
-// construct.
-//
-// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+// Extra operation operand structures.
 //===----------------------------------------------------------------------===//
 
-namespace detail {
-template <typename... Mixins>
-struct Clauses : public Mixins... {};
-} // namespace detail
-
-using CancelOperands =
-    detail::Clauses<CancelDirectiveNameClauseOps, IfClauseOps>;
-
-using CancellationPointOperands = detail::Clauses<CancelDirectiveNameClauseOps>;
-
-using CriticalDeclareOperands =
-    detail::Clauses<CriticalNameClauseOps, HintClauseOps>;
-
-// TODO `indirect` clause.
+// TODO: Add `indirect` clause.
 using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
 
-using DistributeOperands =
-    detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
-                    PrivateClauseOps>;
-
-using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
-
-using MaskedOperands = detail::Clauses<FilterClauseOps>;
-
-using OrderedOperands = detail::Clauses<DoacrossClauseOps>;
-
-using OrderedRegionOperands = detail::Clauses<ParallelizationLevelClauseOps>;
-
-using ParallelOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
-                    PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
-
-using SectionsOperands = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
-                                         PrivateClauseOps, ReductionClauseOps>;
-
-using SimdOperands =
-    detail::Clauses<AlignedClauseOps, IfClauseOps, LinearClauseOps,
-                    NontemporalClauseOps, OrderClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, SafelenClauseOps, SimdlenClauseOps>;
-
-using SingleOperands = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
-                                       NowaitClauseOps, PrivateClauseOps>;
-
-// TODO `defaultmap`, `uses_allocators` clauses.
-using TargetOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
-                    HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
-                    IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
-                    PrivateClauseOps, ThreadLimitClauseOps>;
-
-using TargetDataOperands =
-    detail::Clauses<DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    UseDeviceAddrClauseOps, UseDevicePtrClauseOps>;
-
-using TargetEnterExitUpdateDataOperands =
-    detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    NowaitClauseOps>;
-
-// TODO `affinity`, `detach` clauses.
-using TaskOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
-
-using TaskgroupOperands =
-    detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
-
-using TaskloopOperands =
-    detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
-
-using TaskwaitOperands = detail::Clauses<DependClauseOps, NowaitClauseOps>;
-
-using TeamsOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
-
-using WsloopOperands =
-    detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
-                    OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, ScheduleClauseOps>;
+/// omp.target_enter_data, omp.target_exit_data and omp.target_update take the
+/// same clauses, so we give the structure to be shared by all of them a
+/// representative name.
+using TargetEnterExitUpdateDataOperands = TargetEnterDataOperands;
 
 } // namespace omp
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef00..f9dcfedfee1051 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -408,10 +408,18 @@ class ElementsAttrBase<Pred condition, string summary> :
   let storageType = [{ ::mlir::ElementsAttr }];
   let returnType = [{ ::mlir::ElementsAttr }];
   let convertFromStorage = "$_self";
+
+  // C++ type of each individual element. Unset here; subclasses override it.
+  string elementReturnType = ?;
+
+  // Number of dimensions of the element collection (1 unless overridden).
+  int rank = 1;
 }
 
 def ElementsAttr : ElementsAttrBase<CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
-                                    "constant vector/tensor attribute">;
+                                    "constant vector/tensor attribute"> {
+  let elementReturnType = [{ ::mlir::Attribute }];
+}
 
 class IntElementsAttrBase<Pred condition, string summary> :
     ElementsAttrBase<And<[CPred<"::llvm::isa<::mlir::DenseIntElementsAttr>($_self)">,
@@ -419,6 +427,7 @@ class IntElementsAttrBase<Pred condition, string summary> :
                      summary> {
   let storageType = [{ ::mlir::DenseIntElementsAttr }];
   let returnType = [{ ::mlir::DenseIntElementsAttr }];
+  let elementReturnType = [{ ::llvm::APInt }];
 
   let convertFromStorage = "$_self";
 }
@@ -428,6 +437,7 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
                      summaryName # " dense array attribute"> {
   let storageType = "::mlir::" # denseAttrName;
   let returnType = "::llvm::ArrayRef<" # cppType # ">";
+  let elementReturnType = cppType;
   let constBuilderCall = "$_builder.get" # denseAttrName # "($0)";
 }
 def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
@@ -486,6 +496,8 @@ class RankedSignlessIntElementsAttr<int width, list<int> dims> :
   let constBuilderCall = "::mlir::DenseIntElementsAttr::get("
     "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") #
     "}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))";
+
+  let rank = !size(dims);
 }
 
 class RankedI32ElementsAttr<list<int> dims> :
@@ -501,6 +513,7 @@ class FloatElementsAttr<int width> : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseElementsAttr }];
   let returnType = [{ ::mlir::DenseElementsAttr }];
+  let elementReturnType = [{ ::llvm::APFloat }];
 
   // Note that this is only constructing scalar elements attribute.
   let constBuilderCall = "::mlir::DenseElementsAttr::get("
@@ -526,6 +539,8 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseFPElementsAttr }];
   let returnType = [{ ::mlir::DenseFPElementsAttr }];
+  let elementReturnType = [{ ::llvm::APFloat }];
+  let rank = !size(dims);
 
   let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>("
   "::mlir::DenseElementsAttr::get("
@@ -544,6 +559,7 @@ def StringElementsAttr : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseElementsAttr }];
   let returnType = [{ ::mlir::DenseElementsAttr }];
+  let elementReturnType = [{ ::llvm::StringRef }];
 
   let convertFromStorage = "$_self";
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 11780f84697b15..c344ed88bea198 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1370,7 +1370,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
 
 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
                          const TargetDataOperands &clauses) {
-  TargetDataOp::build(builder, state, clauses.device, clauses.ifVar,
+  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
                       clauses.mapVars, clauses.useDeviceAddrVars,
                       clauses.useDevicePtrVars);
 }
@@ -1395,7 +1395,7 @@ void TargetEnterDataOp::build(
   MLIRContext *ctx = builder.getContext();
   TargetEnterDataOp::build(builder, state,
                            makeArrayAttr(ctx, clauses.dependKinds),
-                           clauses.dependVars, clauses.device, clauses.ifVar,
+                           clauses.dependVars, clauses.device, clauses.ifExpr,
                            clauses.mapVars, clauses.nowait);
 }
 
@@ -1415,7 +1415,7 @@ void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   TargetExitDataOp::build(builder, state,
                           makeArrayAttr(ctx, clauses.dependKinds),
-                          clauses.dependVars, clauses.device, clauses.ifVar,
+                          clauses.dependVars, clauses.device, clauses.ifExpr,
                           clauses.mapVars, clauses.nowait);
 }
 
@@ -1434,7 +1434,7 @@ void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                            const TargetEnterExitUpdateDataOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
-                        clauses.dependVars, clauses.device, clauses.ifVar,
+                        clauses.dependVars, clauses.device, clauses.ifExpr,
                         clauses.mapVars, clauses.nowait);
 }
 
@@ -1456,7 +1456,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                   makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifVar,
+                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1488,9 +1488,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifVar, clauses.numThreads, clauses.privateVars,
+                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
                     clauses.procBindKind, clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
@@ -1583,13 +1582,12 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms.
-  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                 clauses.ifVar, clauses.numTeamsLower, clauses.numTeamsUpper,
-                 /*private_vars=*/{},
-                 /*private_syms=*/nullptr, clauses.reductionVars,
-                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                 makeArrayAttr(ctx, clauses.reductionSyms),
-                 clauses.threadLimit);
+  TeamsOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+      /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
 }
 
 LogicalResult TeamsOp::verify() {
@@ -1769,7 +1767,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
   // privateSyms, reductionVars, reductionByref, reductionSyms.
   SimdOp::build(builder, state, clauses.alignedVars,
-                makeArrayAttr(ctx, clauses.alignments), clauses.ifVar,
+                makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 /*private_vars=*/{}, /*private_syms=*/nullptr,
@@ -1938,7 +1936,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                clauses.final, clauses.ifVar, clauses.inReductionVars,
+                clauses.final, clauses.ifExpr, clauses.inReductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/{}, /*private_syms=*/nullptr,
@@ -1984,7 +1982,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskloopOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.final, clauses.grainsize, clauses.ifVar, clauses.inReductionVars,
+      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
@@ -2363,7 +2361,7 @@ LogicalResult AtomicCaptureOp::verifyRegions() {
 
 void CancelOp::build(OpBuilder &builder, OperationState &state,
                      const CancelOperands &clauses) {
-  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifVar);
+  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
 }
 
 LogicalResult CancelOp::verify() {
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
new file mode 100644
index 00000000000000..b0139eb546e1be
--- /dev/null
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -0,0 +1,78 @@
+// Tests for the automatic generation of OpenMP clause and operation operand
+// structure definitions through `mlir-tblgen -gen-openmp-clause-ops`.
+
+// First generate OmpCommon.td into a temporary directory found via `-I %t`.
+// RUN: mkdir -p %t/mlir/Dialect/OpenMP
+// RUN: mlir-tblgen --gen-directive-decl --directives-dialect=OpenMP \
+// RUN:   %S/../../../llvm/include/llvm/Frontend/OpenMP/OMP.td \
+// RUN:   -I %S/../../../llvm/include > %t/mlir/Dialect/OpenMP/OmpCommon.td
+
+// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s | FileCheck %s
+
+include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
+
+
+def OpenMP_MyFirstClause : OpenMP_Clause<
+    /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+    /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+    /*skipExtraClassDeclaration=*/false> {
+  let arguments = (ins
+    // Simple attributes: represented by their storage type.
+    I32Attr:$int_attr,
+    TypeAttr:$type_attr,
+    DeclareTargetAttr:$omp_attr,
+
+    // Array attributes: SmallVectors of elements (plus dims, if ranked).
+    F32ArrayAttr:$float_array_attr,
+    StrArrayAttr:$str_array_attr,
+    AnyIntElementsAttr:$anyint_elems_attr,
+    RankedF32ElementsAttr<[3, 4, 5]>:$float_nd_elems_attr,
+
+    // Optional attributes: represented like the attribute they wrap.
+    OptionalAttr<BoolAttr>:$opt_bool_attr,
+    OptionalAttr<I64ArrayAttr>:$opt_int_array_attr,
+    OptionalAttr<DenseI8ArrayAttr>:$opt_int_elems_attr,
+
+    // Multi-level composition: a constraint wrapping an optional attribute.
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr
+  );
+}
+// CHECK:      struct MyFirstClauseOps {
+// CHECK-NEXT:   ::mlir::IntegerAttr intAttr;
+// CHECK-NEXT:   ::mlir::TypeAttr typeAttr;
+// CHECK-NEXT:   ::mlir::omp::DeclareTargetAttr ompAttr;
+
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> floatArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> strArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::llvm::APInt> anyintElemsAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::llvm::APFloat> floatNdElemsAttr;
+// CHECK-NEXT:   int floatNdElemsAttrDims[3];
+
+// CHECK-NEXT:   ::mlir::BoolAttr optBoolAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> optIntArrayAttr;
+// CHECK-NEXT:   ::llvm::SmallVector<int8_t> optIntElemsAttr;
+
+// CHECK-NEXT:   ::mlir::IntegerAttr complexOptIntAttr;
+// CHECK-NEXT: }
+
+def OpenMP_MySecondClause : OpenMP_Clause<
+    /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+    /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+    /*skipExtraClassDeclaration=*/false> {
+  let arguments = (ins
+    I32:$int_val,
+    Optional<AnyType>:$opt_any_val,
+    Variadic<Index>:$variadic_index_val
+  );
+}
+// CHECK:      struct MySecondClauseOps {
+// CHECK-NEXT:   ::mlir::Value intVal;
+// CHECK-NEXT:   ::mlir::Value optAnyVal;
+// CHECK-NEXT:   ::llvm::SmallVector<::mlir::Value> variadicIndexVal;
+// CHECK-NEXT: }
+
+def OpenMP_MyFirstOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause]>;
+// CHECK: using MyFirstOperands = detail::Clauses<MyFirstClauseOps>;
+
+def OpenMP_MySecondOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause, OpenMP_MySecondClause]>;
+// CHECK: using MySecondOperands = detail::Clauses<MyFirstClauseOps, MySecondClauseOps>;
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index 51eb43f322e6ad..9d91495078ff0b 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -12,11 +12,52 @@
 
 #include "mlir/TableGen/GenInfo.h"
 
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace llvm;
 
+/// Code defining `detail::Clauses`, a variadic mixin that combines clause
+/// operand structures by publicly inheriting from all of them.
+static const char *const baseMixinClass = R"(
+namespace detail {
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+)";
+
+/// Operation operand structure template; {0}: op name, {1}: clause structs.
+static const char *const operationArgStruct = R"(
+using {0}Operands = detail::Clauses<{1}>;
+)";
+
+/// Strip a sequence of optional prefixes and then suffixes from \c str.
+///
+/// Each entry of \c prefixes is tried exactly once, in order, and dropped if
+/// the string currently starts with it. Then each entry of \c suffixes is
+/// tried exactly once, in order, and dropped if the string ends with it.
+/// Since every affix is considered only once, repetitions are kept:
+///   - str: "PrePreNameSuf1Suf2"
+///   - prefixes: ["Pre"]
+///   - suffixes: ["Suf1", "Suf2"]
+///   - return: "PreNameSuf1"
+static StringRef stripPrefixAndSuffix(StringRef str,
+                                      llvm::ArrayRef<StringRef> prefixes,
+                                      llvm::ArrayRef<StringRef> suffixes) {
+  StringRef stripped = str;
+  // consume_front/consume_back only modify `stripped` when the affix matches.
+  for (StringRef prefix : prefixes)
+    stripped.consume_front(prefix);
+
+  for (StringRef suffix : suffixes)
+    stripped.consume_back(suffix);
+
+  return stripped;
+}
+
 /// Obtain the name of the OpenMP clause a given record inheriting
 /// `OpenMP_Clause` refers to.
 ///
@@ -53,19 +94,8 @@ static StringRef extractOmpClauseName(Record *clause) {
   assert(!clauseClassName.empty() && "clause name must be found");
 
   // Keep only the OpenMP clause name itself for reporting purposes.
-  StringRef prefix = "OpenMP_";
-  StringRef suffixes[] = {"Skip", "Clause"};
-
-  if (clauseClassName.starts_with(prefix))
-    clauseClassName = clauseClassName.substr(prefix.size());
-
-  for (StringRef suffix : suffixes) {
-    if (clauseClassName.ends_with(suffix))
-      clauseClassName =
-          clauseClassName.substr(0, clauseClassName.size() - suffix.size());
-  }
-
-  return clauseClassName;
+  return stripPrefixAndSuffix(clauseClassName, /*prefixes=*/{"OpenMP_"},
+                              /*suffixes=*/{"Skip", "Clause"});
 }
 
 /// Check that the given argument, identified by its name and initialization
@@ -148,6 +178,111 @@ static void verifyClause(Record *op, Record *clause) {
             "or explicitly skipping this field.");
 }
 
+/// Translate the type of an OpenMP clause's argument to its corresponding
+/// representation for clause operand structures.
+///
+/// Superclasses of the argument's record are scanned in order. `OptionalAttr`
+/// is unwrapped and `TypedArrayAttrBase` adds one array level, both recursing;
+/// `ElementsAttrBase` adds its `rank` and yields its `elementReturnType`, and
+/// `Variadic` adds one level. `TypeConstraint`s become `::mlir::Value`; other
+/// attributes use `storageType`, or `::mlir::Attribute` when rank > 0.
+///
+/// \param[in] init The `DefInit` object representing the argument.
+/// \param[out] rank Incremented by each level of array nesting found.
+/// \return the name of the base type of the argument's elements.
+static StringRef translateArgumentType(Init *init, int &rank) {
+  Record *def = cast<DefInit>(init)->getDef();
+  bool isAttr = false, isValue = false;
+
+  for (auto [sc, _] : def->getSuperClasses()) {
+    std::string scName = sc->getNameInitAsString();
+    if (scName == "OptionalAttr")
+      return translateArgumentType(def->getValue("baseAttr")->getValue(), rank);
+
+    if (scName == "TypedArrayAttrBase") {
+      ++rank;
+      return translateArgumentType(def->getValue("elementAttr")->getValue(),
+                                   rank);
+    }
+
+    if (scName == "ElementsAttrBase") {
+      rank += def->getValueAsInt("rank");
+      return def->getValueAsString("elementReturnType").trim();
+    }
+
+    if (scName == "Attr")
+      isAttr = true;
+    else if (scName == "TypeConstraint")
+      isValue = true;
+    else if (scName == "Variadic")
+      ++rank;
+  }
+
+  if (isValue) {
+    assert(!isAttr &&
+           "argument can't be simultaneously a value and an attribute");
+    return "::mlir::Value";
+  }
+
+  assert(isAttr && "argument must be an attribute if it's not a value");
+  return rank > 0 ? "::mlir::Attribute"
+                  : def->getValueAsString("storageType").trim();
+}
+
+/// Generate the `<Name>ClauseOps` structure holding the arguments of the given
+/// \c clause record of type \c OpenMP_Clause; anonymous records are skipped.
+///
+/// Each argument becomes a camel-case field whose type is the base type from
+/// translateArgumentType(), wrapped in a single llvm::SmallVector if rank > 0.
+/// A rank above 1 additionally emits an `int <field>Dims[rank]` member.
+static void genClauseOpsStruct(Record *clause, raw_ostream &os) {
+  if (clause->isAnonymous())
+    return;
+
+  StringRef clauseName = extractOmpClauseName(clause);
+  os << "struct " << clauseName << "ClauseOps {\n";
+
+  DagInit *arguments = clause->getValueAsDag("arguments");
+  for (auto [name, arg] :
+       zip_equal(arguments->getArgNames(), arguments->getArgs())) {
+    int rank = 0;
+    StringRef baseType = translateArgumentType(arg, rank);
+
+    if (rank > 0)
+      os << "  ::llvm::SmallVector<" << baseType << ">";
+    else
+      os << "  " << baseType;
+
+    std::string fieldName =
+        convertToCamelFromSnakeCase(name->getAsUnquotedString(),
+                                    /*capitalizeFirst=*/false);
+    os << " " << fieldName << ";\n";
+
+    if (rank > 1)
+      os << "  int " << fieldName << "Dims[" << rank << "];\n";
+  }
+
+  os << "};\n";
+}
+
+/// Generate the `<Name>Operands` definition for the clause-related arguments
+/// of the given \c op record of type \c OpenMP_Op; anonymous records are
+/// skipped. It is built by formatting `operationArgStruct` with the op name
+/// (record name minus "OpenMP_" / "Op") and the `<Clause>ClauseOps` names of
+/// every clause in the op's `clauseList`.
+static void genOperandsDef(Record *op, raw_ostream &os) {
+  if (op->isAnonymous())
+    return;
+
+  SmallVector<std::string> clauseNames;
+  for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+    clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
+
+  StringRef opName = stripPrefixAndSuffix(
+      op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
+  os << formatv(operationArgStruct, opName, join(clauseNames, ", "));
+}
+
 /// Verify that all properties of `OpenMP_Clause`s of records deriving from
 /// `OpenMP_Op`s have been inherited by the latter.
 static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
@@ -159,8 +294,31 @@ static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
   return false;
 }
 
+/// Generate clause operand structures and per-operation operand aggregates.
+/// Within namespace `mlir::omp`, emit a `<Name>ClauseOps` structure for each
+/// `OpenMP_Clause` definition, then the base mixin class, then the operand
+/// structure of each `OpenMP_Op` definition. Always returns false.
+static bool genClauseOps(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");
+  for (Record *clause : recordKeeper.getAllDerivedDefinitions("OpenMP_Clause"))
+    genClauseOpsStruct(clause, os);
+
+  // Emit the base mixin class between clause and operation structures.
+  os << baseMixinClass;
+
+  for (Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op"))
+    genOperandsDef(op, os);
+
+  return false;
+}
+
 // Registers the generator to mlir-tblgen.
 static mlir::GenRegistration
     verifyOpenmpOps("verify-openmp-ops",
                     "Verify OpenMP operations (produce no output file)",
                     verifyDecls);
+
+static mlir::GenRegistration
+    genOpenmpClauseOps("gen-openmp-clause-ops", // mlir-tblgen option name
+                       "Generate OpenMP clause operand structures",
+                       genClauseOps);

>From 56a445c48f8e4de7f6c912b4badd8ecf6cc23494 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 27 Aug 2024 17:21:10 +0100
Subject: [PATCH 2/2] Address review comments

---
 mlir/include/mlir/IR/CommonAttrConstraints.td | 18 +----
 mlir/test/mlir-tblgen/openmp-clause-ops.td    | 14 +++-
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           | 71 ++++++++++++-------
 3 files changed, 58 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index f9dcfedfee1051..d99bde1f87ef00 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -408,18 +408,10 @@ class ElementsAttrBase<Pred condition, string summary> :
   let storageType = [{ ::mlir::ElementsAttr }];
   let returnType = [{ ::mlir::ElementsAttr }];
   let convertFromStorage = "$_self";
-
-  // The underlying C++ value type of each element.
-  string elementReturnType = ?;
-
-  // The number of dimensions represented by the element collection.
-  int rank = 1;
 }
 
 def ElementsAttr : ElementsAttrBase<CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
-                                    "constant vector/tensor attribute"> {
-  let elementReturnType = [{ ::mlir::Attribute }];
-}
+                                    "constant vector/tensor attribute">;
 
 class IntElementsAttrBase<Pred condition, string summary> :
     ElementsAttrBase<And<[CPred<"::llvm::isa<::mlir::DenseIntElementsAttr>($_self)">,
@@ -427,7 +419,6 @@ class IntElementsAttrBase<Pred condition, string summary> :
                      summary> {
   let storageType = [{ ::mlir::DenseIntElementsAttr }];
   let returnType = [{ ::mlir::DenseIntElementsAttr }];
-  let elementReturnType = [{ ::llvm::APInt }];
 
   let convertFromStorage = "$_self";
 }
@@ -437,7 +428,6 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
                      summaryName # " dense array attribute"> {
   let storageType = "::mlir::" # denseAttrName;
   let returnType = "::llvm::ArrayRef<" # cppType # ">";
-  let elementReturnType = cppType;
   let constBuilderCall = "$_builder.get" # denseAttrName # "($0)";
 }
 def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
@@ -496,8 +486,6 @@ class RankedSignlessIntElementsAttr<int width, list<int> dims> :
   let constBuilderCall = "::mlir::DenseIntElementsAttr::get("
     "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") #
     "}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))";
-  
-  let rank = !size(dims);
 }
 
 class RankedI32ElementsAttr<list<int> dims> :
@@ -513,7 +501,6 @@ class FloatElementsAttr<int width> : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseElementsAttr }];
   let returnType = [{ ::mlir::DenseElementsAttr }];
-  let elementReturnType = [{ ::llvm::APFloat }];
 
   // Note that this is only constructing scalar elements attribute.
   let constBuilderCall = "::mlir::DenseElementsAttr::get("
@@ -539,8 +526,6 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseFPElementsAttr }];
   let returnType = [{ ::mlir::DenseFPElementsAttr }];
-  let elementReturnType = [{ ::llvm::APFloat }];
-  let rank = !size(dims);
 
   let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>("
   "::mlir::DenseElementsAttr::get("
@@ -559,7 +544,6 @@ def StringElementsAttr : ElementsAttrBase<
 
   let storageType = [{ ::mlir::DenseElementsAttr }];
   let returnType = [{ ::mlir::DenseElementsAttr }];
-  let elementReturnType = [{ ::llvm::SmallString }];
 
   let convertFromStorage = "$_self";
 }
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
index b0139eb546e1be..cee3f2a693bf8c 100644
--- a/mlir/test/mlir-tblgen/openmp-clause-ops.td
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -7,7 +7,7 @@
 // RUN:   %S/../../../llvm/include/llvm/Frontend/OpenMP/OMP.td \
 // RUN:   -I %S/../../../llvm/include > %t/mlir/Dialect/OpenMP/OmpCommon.td
 
-// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s | FileCheck %s
+// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s 2>&1 | FileCheck %s
 
 include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
 
@@ -34,7 +34,13 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
     OptionalAttr<DenseI8ArrayAttr>:$opt_int_elems_attr,
 
     // Multi-level composition
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr,
+
+    // ElementsAttrBase-related edge cases.
+    // CHECK: warning: could not infer array-like attribute element type for argument 'elements_attr', will use bare `storageType`
+    ElementsAttr:$elements_attr,
+    // CHECK: warning: could not infer array-like attribute element type for argument 'string_elements_attr', will use bare `storageType`
+    StringElementsAttr:$string_elements_attr
   );
 }
 // CHECK:      struct MyFirstClauseOps {
@@ -46,13 +52,15 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
 // CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> strArrayAttr;
 // CHECK-NEXT:   ::llvm::SmallVector<::llvm::APInt> anyintElemsAttr;
 // CHECK-NEXT:   ::llvm::SmallVector<::llvm::APFloat> floatNdElemsAttr;
-// CHECK-NEXT:   int floatNdElemsAttrDims[3];
 
 // CHECK-NEXT:   ::mlir::BoolAttr optBoolAttr;
 // CHECK-NEXT:   ::llvm::SmallVector<::mlir::Attribute> optIntArrayAttr;
 // CHECK-NEXT:   ::llvm::SmallVector<int8_t> optIntElemsAttr;
 
 // CHECK-NEXT:   ::mlir::IntegerAttr complexOptIntAttr;
+
+// CHECK-NEXT:   ::mlir::ElementsAttr elementsAttr;
+// CHECK-NEXT:   ::mlir::DenseElementsAttr stringElementsAttr;
 // CHECK-NEXT: }
 
 def OpenMP_MySecondClause : OpenMP_Clause<
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index 9d91495078ff0b..2635a8bce5b11c 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -184,40 +185,59 @@ static void verifyClause(Record *op, Record *clause) {
 /// All kinds of values are represented as `mlir::Value` fields, whereas
 /// attributes are represented based on their `storageType`.
 ///
+/// \param[in] name The name of the argument.
 /// \param[in] init The `DefInit` object representing the argument.
 /// \param[out] rank Number of levels of array nesting associated with the
 ///                  type.
 ///
 /// \return the name of the base type to represent elements of the argument
 ///         type.
-static StringRef translateArgumentType(Init *init, int &rank) {
+static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name,
+                                       Init *init, int &rank) {
   Record *def = cast<DefInit>(init)->getDef();
-  bool isAttr = false, isValue = false;
 
-  for (auto [sc, _] : def->getSuperClasses()) {
-    std::string scName = sc->getNameInitAsString();
-    if (scName == "OptionalAttr")
-      return translateArgumentType(def->getValue("baseAttr")->getValue(), rank);
-
-    if (scName == "TypedArrayAttrBase") {
-      ++rank;
-      return translateArgumentType(def->getValue("elementAttr")->getValue(),
-                                   rank);
-    }
-
-    if (scName == "ElementsAttrBase") {
-      rank += def->getValueAsInt("rank");
-      return def->getValueAsString("elementReturnType").trim();
-    }
-
-    if (scName == "Attr")
-      isAttr = true;
-    else if (scName == "TypeConstraint")
-      isValue = true;
-    else if (scName == "Variadic")
-      ++rank;
+  llvm::StringSet superClasses;
+  for (auto [sc, _] : def->getSuperClasses())
+    superClasses.insert(sc->getNameInitAsString());
+
+  // Handle wrapper-style superclasses.
+  if (superClasses.contains("OptionalAttr"))
+    return translateArgumentType(loc, name,
+                                 def->getValue("baseAttr")->getValue(), rank);
+
+  if (superClasses.contains("TypedArrayAttrBase"))
+    return translateArgumentType(
+        loc, name, def->getValue("elementAttr")->getValue(), ++rank);
+
+  // Handle ElementsAttrBase superclasses.
+  if (superClasses.contains("ElementsAttrBase")) {
+    // TODO: Support properly obtaining rank from ranked types.
+    ++rank;
+
+    if (superClasses.contains("IntElementsAttrBase"))
+      return "::llvm::APInt";
+    if (superClasses.contains("FloatElementsAttr") ||
+        superClasses.contains("RankedFloatElementsAttr"))
+      return "::llvm::APFloat";
+    if (superClasses.contains("DenseArrayAttrBase"))
+      return stripPrefixAndSuffix(def->getValueAsString("returnType"),
+                                  {"::llvm::ArrayRef<"}, {">"});
+
+    // Reset the rank in the case where the base type cannot be inferred, so
+    // that the bare storageType is used instead of a vector.
+    rank = 0;
+    PrintWarning(
+        loc,
+        "could not infer array-like attribute element type for argument '" +
+            name->getAsUnquotedString() + "', will use bare `storageType`");
   }
 
+  // Handle simple attribute and value types.
+  bool isAttr = superClasses.contains("Attr");
+  bool isValue = superClasses.contains("TypeConstraint");
+  if (superClasses.contains("Variadic"))
+    ++rank;
+
   if (isValue) {
     assert(!isAttr &&
            "argument can't be simultaneously a value and an attribute");
@@ -246,7 +266,8 @@ static void genClauseOpsStruct(Record *clause, raw_ostream &os) {
   for (auto [name, arg] :
        zip_equal(arguments->getArgNames(), arguments->getArgs())) {
     int rank = 0;
-    StringRef baseType = translateArgumentType(arg, rank);
+    StringRef baseType =
+        translateArgumentType(clause->getLoc(), name, arg, rank);
 
     if (rank > 0)
       os << "  ::llvm::SmallVector<" << baseType << ">";



More information about the flang-commits mailing list