[Mlir-commits] [mlir] [MLIR] Add ODS support for generating helpers for dialect (discardable) attributes (PR #77024)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 4 15:33:19 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-llvm
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
Work in progress (documentation is still missing).
See https://github.com/llvm/llvm-project/pull/75118 for context
---
Full diff: https://github.com/llvm/llvm-project/pull/77024.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+8-10)
- (modified) mlir/include/mlir/IR/DialectBase.td (+2)
- (modified) mlir/include/mlir/TableGen/Dialect.h (+6)
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-8)
- (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+2-2)
- (modified) mlir/lib/TableGen/Dialect.cpp (+4)
- (modified) mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp (+6-4)
- (modified) mlir/test/lib/Dialect/Test/TestDialect.td (+4)
- (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+85-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
- /// Get the name of the attribute used to annotate external kernel
- /// functions.
- static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
- static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.flat_work_group_size");
- }
- static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
- return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
- }
-
/// The address space value that represents global memory.
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];
+ let discardableAttrs = (ins
+ "::mlir::UnitAttr":$kernel,
+ "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+ "::mlir::StringAttr":$flat_work_group_size,
+ "::mlir::IntegerAttr":$max_flat_work_group_size,
+ "::mlir::IntegerAttr":$waves_per_eu
+ );
+
let useDefaultAttributePrinterParser = 1;
}
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];
+ dag discardableAttrs = (ins);
+
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
#define MLIR_TABLEGEN_DIALECT_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
#include <string>
#include <vector>
@@ -90,6 +92,10 @@ class Dialect {
/// dialect.
bool usePropertiesForAttributes() const;
+ llvm::DagInit *getDiscardableAttributes() const;
+
+ const llvm::Record *getDef() const { return def; }
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
-
+ auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto reqdWorkGroupSizeAttrHelper =
+ rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+ auto flatWorkGroupSizeAttrHelper =
+ rocdlDialect->getFlatWorkGroupSizeAttrHelper();
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
- m.walk([ctx](LLVM::LLVMFuncOp op) {
+ m.walk([&](LLVM::LLVMFuncOp op) {
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
- op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
- blockSizes);
+ reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
// conflicting metadata.
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
}
StringAttr flatSizeAttr =
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
- op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
- flatSizeAttr);
+ flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
}
});
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
converter,
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
- StringAttr::get(&converter.getContext(),
- ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+ ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
if (Runtime::HIP == runtime) {
patterns.add<GPUPrintfOpToHIPLowering>(converter);
} else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// Kernel function attribute should be attached to functions.
- if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+ if (kernelAttrName.getName() == attr.getName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
- return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+ return op->emitError() << "'" << kernelAttrName.getName()
<< "' attribute attached to unexpected op";
}
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+  return getDef()->getValueAsDag("discardableAttrs"); // DialectBase.td field
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+ auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
// Override flat-work-group-size
// TODO: update clients to rocdl.flat_work_group_size instead,
// then remove this half of the branch
- if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+ if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+ attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
attrValueStream << "1," << value.getInt();
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
}
- if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+ if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
}
// Set reqd_work_group_size metadata
- if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+ if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
+ let discardableAttrs = (ins
+ "mlir::IntegerAttr":$discardable_attr_key,
+ "SimpleAAttr":$other_discardable_attr_key
+ );
let extraClassDeclaration = [{
void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
std::function<bool(const llvm::Record *)>>;
} // namespace
+/// Appends each `discardableAttrs` entry as a (name, C++ type) pair.
+static void populateDiscardableAttributes(
+    const Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+  for (unsigned i = 0, e = discardableAttrDag->getNumArgs(); i != e; ++i) {
+    StringRef givenName = discardableAttrDag->getArgNameStr(i);
+    if (givenName.empty())
+      PrintFatalError(dialect.getDef()->getLoc(),
+                      "discardable attributes must be named");
+    llvm::Init *arg = discardableAttrDag->getArg(i);
+    discardableAttributes.emplace_back(givenName.str(),
+                                       arg->getAsUnquotedString());
+  }
+}
+
/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
mlir::OperationName opName) override;
)";
+/// The code block for the discardable attribute helper.
+/// {0}: attribute name in UpperCamelCase, {1}: attribute name as declared,
+/// {2}: C++ attribute type, {3}: attribute name in lowerCamelCase,
+/// {4}: dialect name, used as the attribute name prefix.
+static const char *const discardableAttrHelperDecl = R"(
+  /// Helper to manage the discardable attribute `{1}`.
+  class {0}AttrHelper {{
+    ::mlir::StringAttr name;
+  public:
+    static constexpr ::llvm::StringLiteral getNameStr() {{ return "{4}.{1}"; }
+    constexpr ::mlir::StringAttr getName() const {{ return name; }
+
+    explicit {0}AttrHelper(::mlir::MLIRContext *ctx)
+      : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+    {2} getAttr(::mlir::Operation *op) const {{
+      return op->getAttrOfType<{2}>(getName());
+    }
+    void setAttr(::mlir::Operation *op, {2} val) const {{
+      op->setAttr(getName(), val);
+    }
+  };
+  {0}AttrHelper get{0}AttrHelper() const {{
+    return {3}AttrName;
+  }
+  private:
+  /// Helper instance for `{1}`, initialized in the dialect constructor.
+  {0}AttrHelper {3}AttrName;
+  public:
+)";
+
/// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
// Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
os << regionResultAttrVerifierDecl;
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/true);
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+ attrPair.first, attrPair.second, camelName,
+ dialect.getName());
+ }
+
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// {1}: initialization code that is emitted in the ctor body before calling
/// initialize().
/// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
- : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+ : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+ {3}
+ {{
{1}
initialize();
}
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
)";
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+ const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
// Emit the constructor and destructor.
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+ llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+ SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+ populateDiscardableAttributes(dialect, discardableAttrDag,
+ discardableAttributes);
+ std::string discardableAttributesInit;
+ for (const auto &attrPair : discardableAttributes) {
+ std::string camelName = llvm::convertToCamelFromSnakeCase(
+ attrPair.first, /*capitalizeFirst=*/false);
+ llvm::raw_string_ostream os(discardableAttributesInit);
+ os << ", " << camelName << "AttrName(context)";
+ }
+
os << llvm::formatv(dialectConstructorStr, cppClassName,
- dependentDialectRegistrations, superClassName);
+ dependentDialectRegistrations, superClassName,
+ discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;
- emitDialectDef(*dialect, os);
+ emitDialectDef(*dialect, recordKeeper, os);
return false;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/77024
More information about the Mlir-commits mailing list