[Mlir-commits] [mlir] [NFC] Improve readability of AttrHelper usage (PR #135873)

Simon Waters llvmlistbot at llvm.org
Tue Apr 15 16:57:35 PDT 2025


https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/135873

>From bbbbbacb2a49003711f6ab1e5c5d4743d3f3d912 Mon Sep 17 00:00:00 2001
From: Simon Waters <simon at kernelize.ai>
Date: Tue, 15 Apr 2025 17:35:17 -0500
Subject: [PATCH 1/2] [NFC] Improve readability of AttrHelper usage.

---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |  2 +-
 .../GPUCommon/IndexIntrinsicsOpLowering.h     | 13 ++---
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |  5 +-
 .../ROCDL/ROCDLToLLVMIRTranslation.cpp        |  2 +-
 mlir/tools/mlir-tblgen/DialectGen.cpp         | 53 +++++++++++++------
 5 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f22ad1fd70db2..1b4ea6b1164ec 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -194,7 +194,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
   // Ensure we don't lose information if the function is lowered before its
   // surrounding context.
-  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
+  auto *gpuDialect = gpu::GPUDialect::getLoaded(gpuFuncOp);
   if (knownBlockSize)
     attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                             knownBlockSize);
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 1f158b271e5c6..d7aa5f70d984a 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -74,21 +74,18 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
     // 3. Discardable attributes on a surrounding function of any kind
     // The below code handles these in reverse order so that more important
     // sources overwrite less important ones.
+    auto *gpuDialect = gpu::GPUDialect::getLoaded(op);
     DenseI32ArrayAttr funcBounds = nullptr;
     if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
       switch (indexKind) {
       case IndexKind::Block: {
-        auto blockHelper =
-            gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
-        if (blockHelper.isAttrPresent(funcOp))
-          funcBounds = blockHelper.getAttr(funcOp);
+        auto blockHelper = gpuDialect->getKnownBlockSizeAttrHelper();
+        funcBounds = blockHelper.getAttr(funcOp);
         break;
       }
       case IndexKind::Grid: {
-        auto gridHelper =
-            gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
-        if (gridHelper.isAttrPresent(funcOp))
-          funcBounds = gridHelper.getAttr(funcOp);
+        auto gridHelper = gpuDialect->getKnownGridSizeAttrHelper();
+        funcBounds = gridHelper.getAttr(funcOp);
         break;
       }
       case IndexKind::Other:
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index c6c695b442b4f..4a4c97dfc7bc0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -326,7 +326,7 @@ struct LowerGpuOpsToROCDLOpsPass final
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
-    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+    auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(getContext());
     auto reqdWorkGroupSizeAttrHelper =
         rocdlDialect->getReqdWorkGroupSizeAttrHelper();
     auto flatWorkGroupSizeAttrHelper =
@@ -374,8 +374,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   using mlir::gpu::amd::Runtime;
-  auto *rocdlDialect =
-      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+  auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(converter.getContext());
   populateWithGenerated(patterns);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 88a9d4c2a7ef2..abc46bf0e25f1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -76,7 +76,7 @@ class ROCDLDialectLLVMIRTranslationInterface
   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                  NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+    auto *dialect = ROCDL::ROCDLDialect::getLoaded(op);
     llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
     if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 6cf71d2bb0174..bbd67d274e851 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -110,15 +110,23 @@ tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
 /// {2}: The dialect parent class.
 static const char *const dialectDeclBeginStr = R"(
 class {0} : public ::mlir::{2} {
+  typedef {0} DialectType;
   explicit {0}(::mlir::MLIRContext *context);
 
   void initialize();
   friend class ::mlir::MLIRContext;
 public:
   ~{0}() override;
-  static constexpr ::llvm::StringLiteral getDialectNamespace() {
+  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
     return ::llvm::StringLiteral("{1}");
   }
+  static const DialectType *getLoaded(::mlir::MLIRContext &context) {{
+    return context.getLoadedDialect<DialectType>();
+  }
+  static const DialectType *getLoaded(::mlir::MLIRContext *context) {{
+    return getLoaded(*context);
+  }
+  static const DialectType *getLoaded(::mlir::Operation *operation);
 )";
 
 /// Registration for a single dependent dialect: to be inserted in the ctor
@@ -206,28 +214,28 @@ static const char *const discardableAttrHelperDecl = R"(
       static constexpr ::llvm::StringLiteral getNameStr() {{
         return "{4}.{1}";
       }
-      constexpr ::mlir::StringAttr getName() {{
+      constexpr ::mlir::StringAttr getName() const {{
         return name;
       }
 
       {0}AttrHelper(::mlir::MLIRContext *ctx)
         : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
 
-     {2} getAttr(::mlir::Operation *op) {{
-       return op->getAttrOfType<{2}>(name);
-     }
-     void setAttr(::mlir::Operation *op, {2} val) {{
-       op->setAttr(name, val);
-     }
-     bool isAttrPresent(::mlir::Operation *op) {{
-       return op->hasAttrOfType<{2}>(name);
-     }
-     void removeAttr(::mlir::Operation *op) {{
-       assert(op->hasAttrOfType<{2}>(name));
-       op->removeAttr(name);
-     }
+      {2} getAttr(::mlir::Operation *op) const {{
+        return op->getAttrOfType<{2}>(name);
+      }
+      void setAttr(::mlir::Operation *op, {2} val) const {{
+        op->setAttr(name, val);
+      }
+      bool isAttrPresent(::mlir::Operation *op) const {{
+        return op->hasAttrOfType<{2}>(name);
+      }
+      void removeAttr(::mlir::Operation *op) const {{
+        assert(op->hasAttrOfType<{2}>(name));
+        op->removeAttr(name);
+      }
    };
-   {0}AttrHelper get{0}AttrHelper() {
+   const {0}AttrHelper get{0}AttrHelper() const {
      return {3}AttrName;
    }
  private:
@@ -341,7 +349,17 @@ static const char *const dialectDestructorStr = R"(
 {0}::~{0}() = default;
 
 )";
+  
+/// The code block to generate a member funcs.
+///
+/// {0}: The name of the dialect class.
+static const char *const dialectStaticMemberDefs = R"(
+const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
+  return getLoaded(*operation->getContext());
+}
 
+)";
+  
 static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
                            raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();
@@ -388,6 +406,9 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
                       discardableAttributesInit);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
+
+  // Emit member function definitions.
+  os << llvm::formatv(dialectStaticMemberDefs, cppClassName);
 }
 
 static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {

>From 83f8b7eefc4039df3c4c127712dec0a91b0272b4 Mon Sep 17 00:00:00 2001
From: Simon Waters <simon at kernelize.ai>
Date: Tue, 15 Apr 2025 18:56:56 -0500
Subject: [PATCH 2/2] * format

---
 mlir/tools/mlir-tblgen/DialectGen.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index bbd67d274e851..700f68e940f13 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -349,7 +349,7 @@ static const char *const dialectDestructorStr = R"(
 {0}::~{0}() = default;
 
 )";
-  
+
 /// The code block to generate a member funcs.
 ///
 /// {0}: The name of the dialect class.
@@ -359,7 +359,7 @@ const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
 }
 
 )";
-  
+
 static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
                            raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();



More information about the Mlir-commits mailing list