[llvm-branch-commits] [mlir] [MLIR][OpenMP] Automate operand structure definition (PR #99508)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jul 19 07:02:45 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/99508
>From 1d99939c020aab8650cd20df24e0b1e71726ae90 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 17 Jul 2024 13:26:09 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Automate operand structure definition
This patch adds the "gen-openmp-clause-ops" `mlir-tblgen` generator, which
automatically produces the structure definitions previously hand-written in
OpenMPClauseOperands.h from the information contained in OpenMPOps.td and
OpenMPClauses.td.
Changes introduced to the `ElementsAttrBase` common tablegen class, as well as
to some of its subclasses, add more fine-grained information about the shape of
these attributes and the type of their elements. This information is needed to
properly generate the types that represent these attributes within the
produced operand structures.
The original header is kept so that it can still define similar structures
that are not directly tied to any single `OpenMP_Clause` or `OpenMP_Op`
tablegen definition.
---
.../mlir/Dialect/OpenMP/CMakeLists.txt | 1 +
.../Dialect/OpenMP/OpenMPClauseOperands.h | 290 +-----------------
mlir/include/mlir/IR/CommonAttrConstraints.td | 18 +-
mlir/test/mlir-tblgen/openmp-clause-ops.td | 78 +++++
mlir/tools/mlir-tblgen/OmpOpGen.cpp | 174 ++++++++++-
5 files changed, 266 insertions(+), 295 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/openmp-clause-ops.td
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index d3422f6e48b06..23ccba3067bcb 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -17,6 +17,7 @@ mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
mlir_tablegen(OpenMPOpsDialect.cpp.inc -gen-dialect-defs -dialect=omp)
mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
+mlir_tablegen(OpenMPClauseOps.h.inc -gen-openmp-clause-ops)
mlir_tablegen(OpenMPOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=omp)
mlir_tablegen(OpenMPOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=omp)
mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index f4a87d52a172e..e5b4de4908966 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -23,303 +23,31 @@
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPClauseOps.h.inc"
+
namespace mlir {
namespace omp {
//===----------------------------------------------------------------------===//
-// Mixin structures defining MLIR operands associated with each OpenMP clause.
+// Extra clause operand structures.
//===----------------------------------------------------------------------===//
-struct AlignedClauseOps {
- llvm::SmallVector<Value> alignedVars;
- llvm::SmallVector<Attribute> alignments;
-};
-
-struct AllocateClauseOps {
- llvm::SmallVector<Value> allocateVars, allocatorVars;
-};
-
-struct CancelDirectiveNameClauseOps {
- ClauseCancellationConstructTypeAttr cancelDirective;
-};
-
-struct CollapseClauseOps {
- llvm::SmallVector<Value> collapseLowerBound, collapseUpperBound, collapseStep;
-};
-
-struct CopyprivateClauseOps {
- llvm::SmallVector<Value> copyprivateVars;
- llvm::SmallVector<Attribute> copyprivateSyms;
-};
-
-struct CriticalNameClauseOps {
- StringAttr symName;
-};
-
-struct DependClauseOps {
- llvm::SmallVector<Attribute> dependKinds;
- llvm::SmallVector<Value> dependVars;
-};
-
-struct DeviceClauseOps {
- Value device;
-};
-
struct DeviceTypeClauseOps {
// The default capture type.
DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
};
-struct DistScheduleClauseOps {
- UnitAttr distScheduleStatic;
- Value distScheduleChunkSize;
-};
-
-struct DoacrossClauseOps {
- ClauseDependAttr doacrossDependType;
- IntegerAttr doacrossNumLoops;
- llvm::SmallVector<Value> doacrossDependVars;
-};
-
-struct FilterClauseOps {
- Value filteredThreadId;
-};
-
-struct FinalClauseOps {
- Value final;
-};
-
-struct GrainsizeClauseOps {
- Value grainsize;
-};
-
-struct HasDeviceAddrClauseOps {
- llvm::SmallVector<Value> hasDeviceAddrVars;
-};
-
-struct HintClauseOps {
- IntegerAttr hint;
-};
-
-struct IfClauseOps {
- Value ifVar;
-};
-
-struct InReductionClauseOps {
- llvm::SmallVector<Value> inReductionVars;
- llvm::SmallVector<bool> inReductionByref;
- llvm::SmallVector<Attribute> inReductionSyms;
-};
-
-struct IsDevicePtrClauseOps {
- llvm::SmallVector<Value> isDevicePtrVars;
-};
-
-struct LinearClauseOps {
- llvm::SmallVector<Value> linearVars, linearStepVars;
-};
-
-struct LoopRelatedOps {
- UnitAttr loopInclusive;
-};
-
-struct MapClauseOps {
- llvm::SmallVector<Value> mapVars;
-};
-
-struct MergeableClauseOps {
- UnitAttr mergeable;
-};
-
-struct NogroupClauseOps {
- UnitAttr nogroup;
-};
-
-struct NontemporalClauseOps {
- llvm::SmallVector<Value> nontemporalVars;
-};
-
-struct NowaitClauseOps {
- UnitAttr nowait;
-};
-
-struct NumTasksClauseOps {
- Value numTasks;
-};
-
-struct NumTeamsClauseOps {
- Value numTeamsLower, numTeamsUpper;
-};
-
-struct NumThreadsClauseOps {
- Value numThreads;
-};
-
-struct OrderClauseOps {
- ClauseOrderKindAttr order;
- OrderModifierAttr orderMod;
-};
-
-struct OrderedClauseOps {
- IntegerAttr ordered;
-};
-
-struct ParallelizationLevelClauseOps {
- UnitAttr parLevelSimd;
-};
-
-struct PriorityClauseOps {
- Value priority;
-};
-
-struct PrivateClauseOps {
- // SSA values that correspond to "original" values being privatized.
- // They refer to the SSA value outside the OpenMP region from which a clone is
- // created inside the region.
- llvm::SmallVector<Value> privateVars;
- // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
- // ops).
- llvm::SmallVector<Attribute> privateSyms;
-};
-
-struct ProcBindClauseOps {
- ClauseProcBindKindAttr procBindKind;
-};
-
-struct ReductionClauseOps {
- llvm::SmallVector<Value> reductionVars;
- llvm::SmallVector<bool> reductionByref;
- llvm::SmallVector<Attribute> reductionSyms;
-};
-
-struct SafelenClauseOps {
- IntegerAttr safelen;
-};
-
-struct ScheduleClauseOps {
- ClauseScheduleKindAttr scheduleKind;
- Value scheduleChunk;
- ScheduleModifierAttr scheduleMod;
- UnitAttr scheduleSimd;
-};
-
-struct SimdlenClauseOps {
- IntegerAttr simdlen;
-};
-
-struct TaskReductionClauseOps {
- llvm::SmallVector<Value> taskReductionVars;
- llvm::SmallVector<bool> taskReductionByref;
- llvm::SmallVector<Attribute> taskReductionSyms;
-};
-
-struct ThreadLimitClauseOps {
- Value threadLimit;
-};
-
-struct UntiedClauseOps {
- UnitAttr untied;
-};
-
-struct UseDeviceAddrClauseOps {
- llvm::SmallVector<Value> useDeviceAddrVars;
-};
-
-struct UseDevicePtrClauseOps {
- llvm::SmallVector<Value> useDevicePtrVars;
-};
-
//===----------------------------------------------------------------------===//
-// Structures defining clause operands associated with each OpenMP leaf
-// construct.
-//
-// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+// Extra operation operand structures.
//===----------------------------------------------------------------------===//
-namespace detail {
-template <typename... Mixins>
-struct Clauses : public Mixins... {};
-} // namespace detail
-
-using CancelOperands =
- detail::Clauses<CancelDirectiveNameClauseOps, IfClauseOps>;
-
-using CancellationPointOperands = detail::Clauses<CancelDirectiveNameClauseOps>;
-
-using CriticalDeclareOperands =
- detail::Clauses<CriticalNameClauseOps, HintClauseOps>;
-
-// TODO `indirect` clause.
+// TODO: Add `indirect` clause.
using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
-using DistributeOperands =
- detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
- PrivateClauseOps>;
-
-using LoopNestOperands = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
-
-using MaskedOperands = detail::Clauses<FilterClauseOps>;
-
-using OrderedOperands = detail::Clauses<DoacrossClauseOps>;
-
-using OrderedRegionOperands = detail::Clauses<ParallelizationLevelClauseOps>;
-
-using ParallelOperands =
- detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
- PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
-
-using SectionsOperands = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
- PrivateClauseOps, ReductionClauseOps>;
-
-// TODO `linear` clause.
-using SimdOperands =
- detail::Clauses<AlignedClauseOps, IfClauseOps, NontemporalClauseOps,
- OrderClauseOps, PrivateClauseOps, ReductionClauseOps,
- SafelenClauseOps, SimdlenClauseOps>;
-
-using SingleOperands = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
- NowaitClauseOps, PrivateClauseOps>;
-
-// TODO `defaultmap`, `uses_allocators` clauses.
-using TargetOperands =
- detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
- HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
- IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
- PrivateClauseOps, ThreadLimitClauseOps>;
-
-using TargetDataOperands =
- detail::Clauses<DeviceClauseOps, IfClauseOps, MapClauseOps,
- UseDeviceAddrClauseOps, UseDevicePtrClauseOps>;
-
-using TargetEnterExitUpdateDataOperands =
- detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
- NowaitClauseOps>;
-
-// TODO `affinity`, `detach` clauses.
-using TaskOperands =
- detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
- IfClauseOps, InReductionClauseOps, MergeableClauseOps,
- PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
-
-using TaskgroupOperands =
- detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
-
-using TaskloopOperands =
- detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
- IfClauseOps, InReductionClauseOps, MergeableClauseOps,
- NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
- PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
-
-using TaskwaitOperands = detail::Clauses<DependClauseOps, NowaitClauseOps>;
-
-using TeamsOperands =
- detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
- PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
-
-using WsloopOperands =
- detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
- OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
- ReductionClauseOps, ScheduleClauseOps>;
+// omp.target_enter_data, omp.target_exit_data and omp.target_update take the
+// same clauses, so we give the structure to be shared by all of them a
+// representative name.
+using TargetEnterExitUpdateDataOperands = TargetEnterDataOperands;
} // namespace omp
} // namespace mlir
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef0..f9dcfedfee105 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -408,10 +408,18 @@ class ElementsAttrBase<Pred condition, string summary> :
let storageType = [{ ::mlir::ElementsAttr }];
let returnType = [{ ::mlir::ElementsAttr }];
let convertFromStorage = "$_self";
+
+ // The underlying C++ value type of each element.
+ string elementReturnType = ?;
+
+ // The number of dimensions represented by the element collection.
+ int rank = 1;
}
def ElementsAttr : ElementsAttrBase<CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
- "constant vector/tensor attribute">;
+ "constant vector/tensor attribute"> {
+ let elementReturnType = [{ ::mlir::Attribute }];
+}
class IntElementsAttrBase<Pred condition, string summary> :
ElementsAttrBase<And<[CPred<"::llvm::isa<::mlir::DenseIntElementsAttr>($_self)">,
@@ -419,6 +427,7 @@ class IntElementsAttrBase<Pred condition, string summary> :
summary> {
let storageType = [{ ::mlir::DenseIntElementsAttr }];
let returnType = [{ ::mlir::DenseIntElementsAttr }];
+ let elementReturnType = [{ ::llvm::APInt }];
let convertFromStorage = "$_self";
}
@@ -428,6 +437,7 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
summaryName # " dense array attribute"> {
let storageType = "::mlir::" # denseAttrName;
let returnType = "::llvm::ArrayRef<" # cppType # ">";
+ let elementReturnType = cppType;
let constBuilderCall = "$_builder.get" # denseAttrName # "($0)";
}
def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
@@ -486,6 +496,8 @@ class RankedSignlessIntElementsAttr<int width, list<int> dims> :
let constBuilderCall = "::mlir::DenseIntElementsAttr::get("
"::mlir::RankedTensorType::get({" # !interleave(dims, ", ") #
"}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))";
+
+ let rank = !size(dims);
}
class RankedI32ElementsAttr<list<int> dims> :
@@ -501,6 +513,7 @@ class FloatElementsAttr<int width> : ElementsAttrBase<
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];
+ let elementReturnType = [{ ::llvm::APFloat }];
// Note that this is only constructing scalar elements attribute.
let constBuilderCall = "::mlir::DenseElementsAttr::get("
@@ -526,6 +539,8 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
let storageType = [{ ::mlir::DenseFPElementsAttr }];
let returnType = [{ ::mlir::DenseFPElementsAttr }];
+ let elementReturnType = [{ ::llvm::APFloat }];
+ let rank = !size(dims);
let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>("
"::mlir::DenseElementsAttr::get("
@@ -544,6 +559,7 @@ def StringElementsAttr : ElementsAttrBase<
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];
+  // `::llvm::SmallString` is a class template that requires an inline size
+  // argument, so it cannot name an element type on its own (generated code
+  // would contain `::llvm::SmallVector<::llvm::SmallString>`). String elements
+  // of a dense elements attribute can be read as `::llvm::StringRef`s instead.
+  let elementReturnType = [{ ::llvm::StringRef }];
let convertFromStorage = "$_self";
}
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
new file mode 100644
index 0000000000000..b0139eb546e1b
--- /dev/null
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -0,0 +1,78 @@
+// Tablegen tests for the automatic generation of OpenMP clause operand
+// structure definitions.
+
+// First, run tablegen to generate OmpCommon.td in a temporary directory.
+// RUN: mkdir -p %t/mlir/Dialect/OpenMP
+// RUN: mlir-tblgen --gen-directive-decl --directives-dialect=OpenMP \
+// RUN: %S/../../../llvm/include/llvm/Frontend/OpenMP/OMP.td \
+// RUN: -I %S/../../../llvm/include > %t/mlir/Dialect/OpenMP/OmpCommon.td
+
+// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s | FileCheck %s
+
+include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
+
+
+// Clause whose arguments are all attributes: simple ones, array-like ones,
+// optional wrappers and a constraint wrapping an optional attribute. The field
+// order of the generated structure follows the argument order below.
+def OpenMP_MyFirstClause : OpenMP_Clause<
+ /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+ /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+ /*skipExtraClassDeclaration=*/false> {
+ let arguments = (ins
+ // Simple attributes
+ I32Attr:$int_attr,
+ TypeAttr:$type_attr,
+ DeclareTargetAttr:$omp_attr,
+
+ // Array attributes
+ F32ArrayAttr:$float_array_attr,
+ StrArrayAttr:$str_array_attr,
+ AnyIntElementsAttr:$anyint_elems_attr,
+ RankedF32ElementsAttr<[3, 4, 5]>:$float_nd_elems_attr,
+
+ // Optional attributes
+ OptionalAttr<BoolAttr>:$opt_bool_attr,
+ OptionalAttr<I64ArrayAttr>:$opt_int_array_attr,
+ OptionalAttr<DenseI8ArrayAttr>:$opt_int_elems_attr,
+
+ // Multi-level composition
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr
+ );
+}
+// CHECK: struct MyFirstClauseOps {
+// CHECK-NEXT: ::mlir::IntegerAttr intAttr;
+// CHECK-NEXT: ::mlir::TypeAttr typeAttr;
+// CHECK-NEXT: ::mlir::omp::DeclareTargetAttr ompAttr;
+
+// CHECK-NEXT: ::llvm::SmallVector<::mlir::Attribute> floatArrayAttr;
+// CHECK-NEXT: ::llvm::SmallVector<::mlir::Attribute> strArrayAttr;
+// CHECK-NEXT: ::llvm::SmallVector<::llvm::APInt> anyintElemsAttr;
+// CHECK-NEXT: ::llvm::SmallVector<::llvm::APFloat> floatNdElemsAttr;
+// CHECK-NEXT: int floatNdElemsAttrDims[3];
+
+// CHECK-NEXT: ::mlir::BoolAttr optBoolAttr;
+// CHECK-NEXT: ::llvm::SmallVector<::mlir::Attribute> optIntArrayAttr;
+// CHECK-NEXT: ::llvm::SmallVector<int8_t> optIntElemsAttr;
+
+// CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr;
+// CHECK-NEXT: }
+
+// Clause whose arguments are all SSA values: a single required value, an
+// optional value and a variadic list of values.
+def OpenMP_MySecondClause : OpenMP_Clause<
+ /*isRequired=*/false, /*skipTraits=*/false, /*skipArguments=*/false,
+ /*skipAssemblyFormat=*/false, /*skipDescription=*/false,
+ /*skipExtraClassDeclaration=*/false> {
+ let arguments = (ins
+ I32:$int_val,
+ Optional<AnyType>:$opt_any_val,
+ Variadic<Index>:$variadic_index_val
+ );
+}
+// CHECK: struct MySecondClauseOps {
+// CHECK-NEXT: ::mlir::Value intVal;
+// CHECK-NEXT: ::mlir::Value optAnyVal;
+// CHECK-NEXT: ::llvm::SmallVector<::mlir::Value> variadicIndexVal;
+// CHECK-NEXT: }
+
+// Each operation gets an operands alias that aggregates the structures of the
+// clauses it lists, in the same order as its `clauses` list.
+def OpenMP_MyFirstOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause]>;
+// CHECK: using MyFirstOperands = detail::Clauses<MyFirstClauseOps>;
+
+def OpenMP_MySecondOp : OpenMP_Op<"op", clauses=[OpenMP_MyFirstClause, OpenMP_MySecondClause]>;
+// CHECK: using MySecondOperands = detail::Clauses<MyFirstClauseOps, MySecondClauseOps>;
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index 51eb43f322e6a..d4c2cd48e891f 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -12,11 +12,43 @@
#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace llvm;
+/// Emitted once per generated file: the variadic `detail::Clauses` template,
+/// which aggregates clause operand structures by inheriting from all of them.
+static constexpr char baseMixinClass[] = R"(
+namespace detail {
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+)";
+
+/// `formatv` pattern for a per-operation operands alias: `{0}` is the
+/// operation name and `{1}` the comma-separated clause structure names.
+static constexpr char operationArgStruct[] = R"(
+using {0}Operands = detail::Clauses<{1}>;
+)";
+
+/// Remove multiple optional prefixes and suffixes from \c str.
+static StringRef stripPrefixAndSuffix(StringRef str,
+ llvm::ArrayRef<StringRef> prefixes,
+ llvm::ArrayRef<StringRef> suffixes) {
+ // Each prefix is tried once, in order; later prefixes are matched against
+ // what remains after stripping the earlier ones.
+ for (StringRef prefix : prefixes)
+ if (str.starts_with(prefix))
+ str = str.substr(prefix.size());
+
+ // Suffixes are handled the same way, only after all prefixes.
+ for (StringRef suffix : suffixes)
+ if (str.ends_with(suffix))
+ str = str.substr(0, str.size() - suffix.size());
+
+ // The result is a view into the caller's string; no copy is made.
+ return str;
+}
+
/// Obtain the name of the OpenMP clause a given record inheriting
/// `OpenMP_Clause` refers to.
///
@@ -53,19 +85,8 @@ static StringRef extractOmpClauseName(Record *clause) {
assert(!clauseClassName.empty() && "clause name must be found");
// Keep only the OpenMP clause name itself for reporting purposes.
- StringRef prefix = "OpenMP_";
- StringRef suffixes[] = {"Skip", "Clause"};
-
- if (clauseClassName.starts_with(prefix))
- clauseClassName = clauseClassName.substr(prefix.size());
-
- for (StringRef suffix : suffixes) {
- if (clauseClassName.ends_with(suffix))
- clauseClassName =
- clauseClassName.substr(0, clauseClassName.size() - suffix.size());
- }
-
- return clauseClassName;
+ return stripPrefixAndSuffix(clauseClassName, /*prefixes=*/{"OpenMP_"},
+ /*suffixes=*/{"Skip", "Clause"});
}
/// Check that the given argument, identified by its name and initialization
@@ -148,6 +169,110 @@ static void verifyClause(Record *op, Record *clause) {
"or explicitly skipping this field.");
}
+/// Translate the type of an OpenMP clause's argument to its corresponding
+/// representation for clause operand structures.
+///
+/// All kinds of values are represented as `mlir::Value` fields, whereas
+/// attributes are represented based on their `storageType`. `Variadic` values
+/// and array-like attributes (`TypedArrayAttrBase` and `ElementsAttrBase`
+/// subclasses) increase \p rank, so the caller can wrap the returned base type
+/// in a container. `ElementsAttrBase` subclasses that leave
+/// `elementReturnType` unset are handled like any other attribute.
+///
+/// \param[in] init The `DefInit` object representing the argument.
+/// \param[in,out] rank Number of levels of array nesting associated with the
+///                     type. It is incremented, so callers should start at 0.
+///
+/// \return the name of the base type to represent elements of the argument
+///         type.
+static StringRef translateArgumentType(Init *init, int &rank) {
+  Record *def = cast<DefInit>(init)->getDef();
+  // `isAttr` is only read by assertions; without the attribute, NDEBUG builds
+  // warn about a variable that is set but never used.
+  [[maybe_unused]] bool isAttr = false;
+  bool isValue = false;
+
+  for (auto [sc, _] : def->getSuperClasses()) {
+    std::string scName = sc->getNameInitAsString();
+    if (scName == "OptionalAttr")
+      return translateArgumentType(def->getValue("baseAttr")->getValue(), rank);
+
+    if (scName == "TypedArrayAttrBase") {
+      ++rank;
+      return translateArgumentType(def->getValue("elementAttr")->getValue(),
+                                   rank);
+    }
+
+    // `elementReturnType` defaults to `?` in `ElementsAttrBase`. Reading an
+    // unset field with getValueAsString() is a fatal error, so only use it when
+    // a subclass provides it; otherwise keep scanning and fall back to the
+    // generic attribute handling below.
+    if (scName == "ElementsAttrBase" &&
+        !def->isValueUnset("elementReturnType")) {
+      rank += def->getValueAsInt("rank");
+      return def->getValueAsString("elementReturnType").trim();
+    }
+
+    if (scName == "Attr")
+      isAttr = true;
+    else if (scName == "TypeConstraint")
+      isValue = true;
+    else if (scName == "Variadic")
+      ++rank;
+  }
+
+  if (isValue) {
+    assert(!isAttr);
+    return "::mlir::Value";
+  }
+
+  assert(isAttr);
+  return rank > 0 ? "::mlir::Attribute"
+                  : def->getValueAsString("storageType").trim();
+}
+
+/// Generate the structure that represents the arguments of the given \c clause
+/// record of type \c OpenMP_Clause.
+///
+/// It will contain a field for each argument, using the same name translated to
+/// camel case. The field type is the base type returned by
+/// translateArgumentType(), wrapped in a single llvm::SmallVector when the
+/// argument has any level of array nesting. Arguments with a rank greater than
+/// one also get a `<field>Dims` array with one entry per dimension.
+///
+/// Anonymous records produce no output. Unnamed arguments are reported as fatal
+/// errors, because no field name can be derived for them.
+static void genClauseOpsStruct(Record *clause, raw_ostream &os) {
+  if (clause->isAnonymous())
+    return;
+
+  StringRef clauseName = extractOmpClauseName(clause);
+  os << "struct " << clauseName << "ClauseOps {\n";
+
+  DagInit *arguments = clause->getValueAsDag("arguments");
+  for (auto [name, arg] :
+       zip_equal(arguments->getArgNames(), arguments->getArgs())) {
+    // Unnamed dag arguments have a null name. Report them with a diagnostic
+    // instead of dereferencing the null pointer below.
+    if (!name)
+      PrintFatalError(clause->getLoc(), "all arguments of clause '" +
+                                            clauseName + "' must be named");
+
+    int rank = 0;
+    StringRef baseType = translateArgumentType(arg, rank);
+
+    if (rank > 0)
+      os << " ::llvm::SmallVector<" << baseType << ">";
+    else
+      os << " " << baseType;
+
+    std::string fieldName =
+        convertToCamelFromSnakeCase(name->getAsUnquotedString(),
+                                    /*capitalizeFirst=*/false);
+    os << " " << fieldName << ";\n";
+
+    // Multi-dimensional collections are flattened into the single vector
+    // above, so their dimensions are stored alongside it.
+    if (rank > 1)
+      os << " int " << fieldName << "Dims[" << rank << "];\n";
+  }
+
+  os << "};\n";
+}
+
+/// Emit the `<Op>Operands` alias for an \c OpenMP_Op record.
+///
+/// The alias instantiates `detail::Clauses` with the operand structure of every
+/// clause in the operation's `clauseList`, in the same order. Anonymous
+/// records produce no output.
+static void genOperandsDef(Record *op, raw_ostream &os) {
+  if (op->isAnonymous())
+    return;
+
+  // "OpenMP_FooOp" -> "Foo".
+  StringRef aliasPrefix = stripPrefixAndSuffix(
+      op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
+
+  SmallVector<std::string> mixinNames;
+  for (Record *clause : op->getValueAsListOfDefs("clauseList"))
+    mixinNames.push_back(extractOmpClauseName(clause).str() + "ClauseOps");
+
+  os << formatv(operationArgStruct, aliasPrefix, join(mixinNames, ", "));
+}
+
/// Verify that all properties of `OpenMP_Clause`s of records deriving from
/// `OpenMP_Op`s have been inherited by the latter.
static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
@@ -159,8 +284,31 @@ static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
return false;
}
+/// Emit the contents of OpenMPClauseOps.h.inc inside `namespace mlir::omp`.
+///
+/// The output contains:
+/// - one operand structure per `OpenMP_Clause` definition;
+/// - the `detail::Clauses` mixin template;
+/// - one `<Op>Operands` alias per `OpenMP_Op` definition, built from that
+///   operation's clauses.
+static bool genClauseOps(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  // The namespace opened here is closed when `ns` goes out of scope, after
+  // everything below has been written.
+  mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");
+
+  auto clauseDefs = recordKeeper.getAllDerivedDefinitions("OpenMP_Clause");
+  for (Record *clause : clauseDefs)
+    genClauseOpsStruct(clause, os);
+
+  // The aliases emitted next refer to this template, so it must precede them.
+  os << baseMixinClass;
+
+  auto opDefs = recordKeeper.getAllDerivedDefinitions("OpenMP_Op");
+  for (Record *op : opDefs)
+    genOperandsDef(op, os);
+
+  return false;
+}
+
// Registers the generator to mlir-tblgen.
static mlir::GenRegistration
verifyOpenmpOps("verify-openmp-ops",
"Verify OpenMP operations (produce no output file)",
verifyDecls);
+
+// Produces OpenMPClauseOps.h.inc (see the OpenMP dialect CMakeLists.txt),
+// which is included by OpenMPClauseOperands.h.
+static mlir::GenRegistration
+ genOpenmpClauseOps("gen-openmp-clause-ops",
+ "Generate OpenMP clause operand structures",
+ genClauseOps);
>From 19ea9b4f6cd6811d5f01ac4e2d5b0e4073230f6f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 19 Jul 2024 15:02:33 +0100
Subject: [PATCH 2/2] Address review comments
---
mlir/tools/mlir-tblgen/OmpOpGen.cpp | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index d4c2cd48e891f..ee3347609ddfd 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -35,6 +35,15 @@ using {0}Operands = detail::Clauses<{1}>;
)";
/// Remove multiple optional prefixes and suffixes from \c str.
+///
+/// Each prefix and suffix is removed at most once, in the order in which it
+/// appears in the \c prefixes or \c suffixes argument, and all prefixes are
+/// processed before any suffix. This means it will behave as shown in the
+/// following example:
+/// - str: "PrePreNameSuf1Suf2"
+/// - prefixes: ["Pre"]
+/// - suffixes: ["Suf1", "Suf2"]
+/// - return: "PreNameSuf1"
static StringRef stripPrefixAndSuffix(StringRef str,
llvm::ArrayRef<StringRef> prefixes,
llvm::ArrayRef<StringRef> suffixes) {
@@ -177,10 +186,10 @@ static void verifyClause(Record *op, Record *clause) {
///
/// \param[in] init The `DefInit` object representing the argument.
/// \param[out] rank Number of levels of array nesting associated with the
-/// type.
+/// type.
///
/// \return the name of the base type to represent elements of the argument
-/// type.
+/// type.
static StringRef translateArgumentType(Init *init, int &rank) {
Record *def = cast<DefInit>(init)->getDef();
bool isAttr = false, isValue = false;
@@ -210,11 +219,12 @@ static StringRef translateArgumentType(Init *init, int &rank) {
}
if (isValue) {
- assert(!isAttr);
+ assert(!isAttr &&
+ "argument can't be simultaneously a value and an attribute");
return "::mlir::Value";
}
- assert(isAttr);
+ assert(isAttr && "argument must be an attribute if it's not a value");
return rank > 0 ? "::mlir::Attribute"
: def->getValueAsString("storageType").trim();
}
More information about the llvm-branch-commits
mailing list