[Mlir-commits] [mlir] [MLIR] Add ODS support for generating helpers for dialect (discardable) attributes (PR #77024)
Mehdi Amini
llvmlistbot at llvm.org
Mon Feb 19 18:05:02 PST 2024
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/77024
>From 40a814a5175c51d5764ce2cf321bc4cfc56c74ff Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 14 Dec 2023 20:55:49 +0000
Subject: [PATCH] [MLIR] Add ODS support for generating helpers for dialect
(discardable) attributes
This is a new ODS feature that allows dialects to define a list of
key/value pairs, each representing an attribute type and a name.
This will generate helper classes on the dialect to be able to
manage discardable attributes on operations in a type-safe way.
For example, the `test` dialect can define:
```
let discardableAttrs = (ins
"mlir::IntegerAttr":$discardable_attr_key,
);
```
And the following will be generated in the TestDialect class:
```
/// Helper to manage the discardable attribute `discardable_attr_key`.
class DiscardableAttrKeyAttrHelper {
::mlir::StringAttr name;
public:
static constexpr ::llvm::StringLiteral getNameStr() {
return "test.discardable_attr_key";
}
constexpr ::mlir::StringAttr getName() {
return name;
}
DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx)
: name(::mlir::StringAttr::get(ctx, getNameStr())) {}
mlir::IntegerAttr getAttr(::mlir::Operation *op) {
return op->getAttrOfType<mlir::IntegerAttr>(name);
}
void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) {
op->setAttr(name, val);
}
bool isAttrPresent(::mlir::Operation *op) {
return op->hasAttrOfType<mlir::IntegerAttr>(name);
}
void removeAttr(::mlir::Operation *op) {
assert(op->hasAttrOfType<mlir::IntegerAttr>(name));
op->removeAttr(name);
}
};
DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() {
return discardableAttrKeyAttrName;
}
```
User code having an instance of the TestDialect can then manipulate this
attribute on an operation using:
```
auto helper = testDialect.getDiscardableAttrKeyAttrHelper();
helper.setAttr(op, value);
helper.isAttrPresent(op);
...
```
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18 ++--
mlir/include/mlir/IR/DialectBase.td | 5 +
mlir/include/mlir/TableGen/Dialect.h | 6 ++
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 17 ++--
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 4 +-
mlir/lib/TableGen/Dialect.cpp | 4 +
.../ROCDL/ROCDLToLLVMIRTranslation.cpp | 11 ++-
mlir/test/lib/Dialect/Test/TestDialect.td | 4 +
mlir/tools/mlir-tblgen/DialectGen.cpp | 96 ++++++++++++++++++-
9 files changed, 136 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 962c159e68a2ee..6b170c8d06f437 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
- /// Get the name of the attribute used to annotate external kernel
- /// functions.
- static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
- static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.flat_work_group_size");
- }
- static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
- }
-
/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];
+ let discardableAttrs = (ins
+ "::mlir::UnitAttr":$kernel,
+ "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+ "::mlir::StringAttr":$flat_work_group_size,
+ "::mlir::IntegerAttr":$max_flat_work_group_size,
+ "::mlir::IntegerAttr":$waves_per_eu
+ );
+
let useDefaultAttributePrinterParser = 1;
}
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..115a01f92706ab 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,11 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];
+ // A list of key/value pairs representing an attribute type and a name.
+ // This will generate helper classes on the dialect to be able to
+ // manage discardable attributes on operations in a type-safe way.
+ dag discardableAttrs = (ins);
+
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
#define MLIR_TABLEGEN_DIALECT_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
#include <string>
#include <vector>
@@ -90,6 +92,10 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;
+ llvm::DagInit *getDiscardableAttributes() const;
+
+ const llvm::Record *getDef() const { return def; }
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 363e6016113b16..4fa3cdcbf85ce2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
-
+ auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto reqdWorkGroupSizeAttrHelper =
+ rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+ auto flatWorkGroupSizeAttrHelper =
+ rocdlDialect->getFlatWorkGroupSizeAttrHelper();
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
- m.walk([ctx](LLVM::LLVMFuncOp op) {
+ m.walk([&](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
- op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
- blockSizes);
+ reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
- op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
- flatSizeAttr);
+ flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
}
});
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
- StringAttr::get(&converter.getContext(),
- ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+ ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
- if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+ if (kernelAttrName.getName() == attr.getName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
- return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+ return op->emitError() << "'" << kernelAttrName.getName()
<< "' attribute attached to unexpected op";
}
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+ return def->getValueAsDag("discardableAttrs");
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 0cbb3da79d151e..93eb456cdc2c4f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+ auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -99,12 +100,12 @@ class ROCDLDialectLLVMIRTranslationInterface
if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
}
-
}
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
- if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+ if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+ attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -119,7 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
- if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+ if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
@@ -136,7 +137,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}
// Set reqd_work_group_size metadata
- if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+ if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
+ let discardableAttrs = (ins
+ "mlir::IntegerAttr":$discardable_attr_key,
+ "SimpleAAttr":$other_discardable_attr_key
+ );
let extraClassDeclaration = [{
void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 4f2021083384fc..46e585a351a2b0 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
std::function<bool(const llvm::Record *)>>;
} // namespace
+static void populateDiscardableAttributes(
+ Dialect &dialect, llvm::DagInit *discardableAttrDag,
+ SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+ for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+ llvm::Init *arg = discardableAttrDag->getArg(i);
+
+ StringRef givenName = discardableAttrDag->getArgNameStr(i);
+ if (givenName.empty())
+ PrintFatalError(dialect.getDef()->getLoc(),
+ "discardable attributes must be named");
+ discardableAttributes.push_back(
+ {givenName.str(), arg->getAsUnquotedString()});
+ }
+}
+
/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
@@ -180,6 +195,44 @@ static const char *const operationInterfaceFallbackDecl = R"(
mlir::OperationName opName) override;
)";
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+ /// Helper to manage the discardable attribute `{1}`.
+ class {0}AttrHelper {{
+ ::mlir::StringAttr name;
+ public:
+ static constexpr ::llvm::StringLiteral getNameStr() {{
+ return "{4}.{1}";
+ }
+ constexpr ::mlir::StringAttr getName() {{
+ return name;
+ }
+
+ {0}AttrHelper(::mlir::MLIRContext *ctx)
+ : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+ {2} getAttr(::mlir::Operation *op) {{
+ return op->getAttrOfType<{2}>(name);
+ }
+ void setAttr(::mlir::Operation *op, {2} val) {{
+ op->setAttr(name, val);
+ }
+ bool isAttrPresent(::mlir::Operation *op) {{
+ return op->hasAttrOfType<{2}>(name);
+ }
+ void removeAttr(::mlir::Operation *op) {{
+ assert(op->hasAttrOfType<{2}>(name));
+ op->removeAttr(name);
+ }
+ };
+ {0}AttrHelper get{0}AttrHelper() {
+ return {3}AttrName;
+ }
+ private:
+ {0}AttrHelper {3}AttrName;
+ public:
+)";
+
/// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
// Emit all nested namespaces.
@@ -215,6 +268,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/true);
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+ attrPair.first, attrPair.second, camelName,
+ dialect.getName());
+ }
+
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
@@ -252,9 +321,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// {1}: Initialization code that is emitted in the ctor body before calling
/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
- : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+ : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+ {3}
+ {{
{1}
initialize();
}
@@ -268,7 +340,9 @@ static const char *const dialectDestructorStr = R"(
)";
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+ const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -295,8 +369,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the constructor and destructor.
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+ std::string discardableAttributesInit;
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ llvm::raw_string_ostream os(discardableAttributesInit);
+ os << ", " << camelName << "AttrName(context)";
+ }
+
os << llvm::formatv(dialectConstructorStr, cppClassName,
- dependentDialectRegistrations, superClassName);
+ dependentDialectRegistrations, superClassName,
+ discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
@@ -313,7 +401,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
- emitDialectDef(*dialect, os);
+ emitDialectDef(*dialect, recordKeeper, os);
return false;
}
More information about the Mlir-commits
mailing list