[Mlir-commits] [mlir] [NFC] Improve readability of AttrHelper usage (PR #135873)
Simon Waters
llvmlistbot at llvm.org
Tue Apr 15 16:57:35 PDT 2025
https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/135873
>From bbbbbacb2a49003711f6ab1e5c5d4743d3f3d912 Mon Sep 17 00:00:00 2001
From: Simon Waters <simon at kernelize.ai>
Date: Tue, 15 Apr 2025 17:35:17 -0500
Subject: [PATCH 1/2] [NFC] Improve readability of AttrHelper usage.
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
.../GPUCommon/IndexIntrinsicsOpLowering.h | 13 ++---
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 5 +-
.../ROCDL/ROCDLToLLVMIRTranslation.cpp | 2 +-
mlir/tools/mlir-tblgen/DialectGen.cpp | 53 +++++++++++++------
5 files changed, 46 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f22ad1fd70db2..1b4ea6b1164ec 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -194,7 +194,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
// Ensure we don't lose information if the function is lowered before its
// surrounding context.
- auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
+ auto *gpuDialect = gpu::GPUDialect::getLoaded(gpuFuncOp);
if (knownBlockSize)
attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
knownBlockSize);
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 1f158b271e5c6..d7aa5f70d984a 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -74,21 +74,18 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
// 3. Discardable attributes on a surrounding function of any kind
// The below code handles these in reverse order so that more important
// sources overwrite less important ones.
+ auto *gpuDialect = gpu::GPUDialect::getLoaded(op);
DenseI32ArrayAttr funcBounds = nullptr;
if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
switch (indexKind) {
case IndexKind::Block: {
- auto blockHelper =
- gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
- if (blockHelper.isAttrPresent(funcOp))
- funcBounds = blockHelper.getAttr(funcOp);
+ auto blockHelper = gpuDialect->getKnownBlockSizeAttrHelper();
+ funcBounds = blockHelper.getAttr(funcOp);
break;
}
case IndexKind::Grid: {
- auto gridHelper =
- gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
- if (gridHelper.isAttrPresent(funcOp))
- funcBounds = gridHelper.getAttr(funcOp);
+ auto gridHelper = gpuDialect->getKnownGridSizeAttrHelper();
+ funcBounds = gridHelper.getAttr(funcOp);
break;
}
case IndexKind::Other:
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index c6c695b442b4f..4a4c97dfc7bc0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -326,7 +326,7 @@ struct LowerGpuOpsToROCDLOpsPass final
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
- auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(getContext());
auto reqdWorkGroupSizeAttrHelper =
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
auto flatWorkGroupSizeAttrHelper =
@@ -374,8 +374,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
using mlir::gpu::amd::Runtime;
- auto *rocdlDialect =
- converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(converter.getContext());
populateWithGenerated(patterns);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 88a9d4c2a7ef2..abc46bf0e25f1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -76,7 +76,7 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ auto *dialect = ROCDL::ROCDLDialect::getLoaded(op);
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 6cf71d2bb0174..bbd67d274e851 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -110,15 +110,23 @@ tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
/// {2}: The dialect parent class.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {
+ typedef {0} DialectType;
explicit {0}(::mlir::MLIRContext *context);
void initialize();
friend class ::mlir::MLIRContext;
public:
~{0}() override;
- static constexpr ::llvm::StringLiteral getDialectNamespace() {
+ static constexpr ::llvm::StringLiteral getDialectNamespace() {{
return ::llvm::StringLiteral("{1}");
}
+ static const DialectType *getLoaded(::mlir::MLIRContext &context) {{
+ return context.getLoadedDialect<DialectType>();
+ }
+ static const DialectType *getLoaded(::mlir::MLIRContext *context) {{
+ return getLoaded(*context);
+ }
+ static const DialectType *getLoaded(::mlir::Operation *operation);
)";
/// Registration for a single dependent dialect: to be inserted in the ctor
@@ -206,28 +214,28 @@ static const char *const discardableAttrHelperDecl = R"(
static constexpr ::llvm::StringLiteral getNameStr() {{
return "{4}.{1}";
}
- constexpr ::mlir::StringAttr getName() {{
+ constexpr ::mlir::StringAttr getName() const {{
return name;
}
{0}AttrHelper(::mlir::MLIRContext *ctx)
: name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
- {2} getAttr(::mlir::Operation *op) {{
- return op->getAttrOfType<{2}>(name);
- }
- void setAttr(::mlir::Operation *op, {2} val) {{
- op->setAttr(name, val);
- }
- bool isAttrPresent(::mlir::Operation *op) {{
- return op->hasAttrOfType<{2}>(name);
- }
- void removeAttr(::mlir::Operation *op) {{
- assert(op->hasAttrOfType<{2}>(name));
- op->removeAttr(name);
- }
+ {2} getAttr(::mlir::Operation *op) const {{
+ return op->getAttrOfType<{2}>(name);
+ }
+ void setAttr(::mlir::Operation *op, {2} val) const {{
+ op->setAttr(name, val);
+ }
+ bool isAttrPresent(::mlir::Operation *op) const {{
+ return op->hasAttrOfType<{2}>(name);
+ }
+ void removeAttr(::mlir::Operation *op) const {{
+ assert(op->hasAttrOfType<{2}>(name));
+ op->removeAttr(name);
+ }
};
- {0}AttrHelper get{0}AttrHelper() {
+ const {0}AttrHelper get{0}AttrHelper() const {
return {3}AttrName;
}
private:
@@ -341,7 +349,17 @@ static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;
)";
+
+/// The code block to generate a member funcs.
+///
+/// {0}: The name of the dialect class.
+static const char *const dialectStaticMemberDefs = R"(
+const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
+ return getLoaded(*operation->getContext());
+}
+)";
+
static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
@@ -388,6 +406,9 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
+
+ // Emit member function definitions.
+ os << llvm::formatv(dialectStaticMemberDefs, cppClassName);
}
static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
>From 83f8b7eefc4039df3c4c127712dec0a91b0272b4 Mon Sep 17 00:00:00 2001
From: Simon Waters <simon at kernelize.ai>
Date: Tue, 15 Apr 2025 18:56:56 -0500
Subject: [PATCH 2/2] * format
---
mlir/tools/mlir-tblgen/DialectGen.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index bbd67d274e851..700f68e940f13 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -349,7 +349,7 @@ static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;
)";
-
+
/// The code block to generate a member funcs.
///
/// {0}: The name of the dialect class.
@@ -359,7 +359,7 @@ const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
}
)";
-
+
static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
More information about the Mlir-commits
mailing list