[Mlir-commits] [mlir] [MLIR] Add ODS support for generating helpers for dialect (discardable) attributes (PR #77024)

Mehdi Amini llvmlistbot at llvm.org
Mon Feb 19 18:05:02 PST 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/77024

>From 40a814a5175c51d5764ce2cf321bc4cfc56c74ff Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 14 Dec 2023 20:55:49 +0000
Subject: [PATCH] [MLIR] Add ODS support for generating helpers for dialect
 (discardable) attributes

This is a new ODS feature that allows dialects to define a list of
key/value pairs, each representing an attribute type and a name.
ODS then generates helper classes on the dialect to manage these
discardable attributes on operations in a type-safe way.

For example, the `test` dialect can define:

```
  let discardableAttrs = (ins
     "mlir::IntegerAttr":$discardable_attr_key
  );
```

And the following will be generated in the TestDialect class:

```
   /// Helper to manage the discardable attribute `discardable_attr_key`.
    class DiscardableAttrKeyAttrHelper {
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {
        return "test.discardable_attr_key";
      }
      constexpr ::mlir::StringAttr getName() {
        return name;
      }

      DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

     mlir::IntegerAttr getAttr(::mlir::Operation *op) {
       return op->getAttrOfType<mlir::IntegerAttr>(name);
     }
     void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) {
       op->setAttr(name, val);
     }
     bool isAttrPresent(::mlir::Operation *op) {
       return op->hasAttrOfType<mlir::IntegerAttr>(name);
     }
     void removeAttr(::mlir::Operation *op) {
       assert(op->hasAttrOfType<mlir::IntegerAttr>(name));
       op->removeAttr(name);
     }
   };
   DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() {
     return discardableAttrKeyAttrName;
   }
```

User code holding an instance of the TestDialect can then manipulate this
attribute on operations using:

```
  auto helper = testDialect.getDiscardableAttrKeyAttrHelper();

  helper.setAttr(op, value);
  helper.isAttrPresent(op);
  ...
```
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  | 18 ++--
 mlir/include/mlir/IR/DialectBase.td           |  5 +
 mlir/include/mlir/TableGen/Dialect.h          |  6 ++
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      | 17 ++--
 mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp   |  4 +-
 mlir/lib/TableGen/Dialect.cpp                 |  4 +
 .../ROCDL/ROCDLToLLVMIRTranslation.cpp        | 11 ++-
 mlir/test/lib/Dialect/Test/TestDialect.td     |  4 +
 mlir/tools/mlir-tblgen/DialectGen.cpp         | 96 ++++++++++++++++++-
 9 files changed, 136 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 962c159e68a2ee..6b170c8d06f437 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
     static constexpr unsigned kPrivateMemoryAddressSpace = 5;
   }];
 
+  let discardableAttrs = (ins
+     "::mlir::UnitAttr":$kernel,
+     "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+     "::mlir::StringAttr":$flat_work_group_size,
+     "::mlir::IntegerAttr":$max_flat_work_group_size,
+     "::mlir::IntegerAttr":$waves_per_eu
+  );
+
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..115a01f92706ab 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,11 @@ class Dialect {
   // pattern or interfaces.
   list<string> dependentDialects = [];
 
+  // A list of named entries, each pairing a C++ attribute type with a
+  // name. Helper classes are generated on the dialect to manage these
+  // discardable attributes on operations in a type-safe way.
+  dag discardableAttrs = (ins);
+
   // The C++ namespace that ops of this dialect should be placed into.
   //
   // By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
 #define MLIR_TABLEGEN_DIALECT_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
 #include <string>
 #include <vector>
 
@@ -90,6 +92,10 @@ class Dialect {
   /// dialect.
   bool usePropertiesForAttributes() const;
 
+  /// Returns the dialect's `discardableAttrs` dag; `(ins)` if it is not set.
+  llvm::DagInit *getDiscardableAttributes() const;
+  /// Returns the underlying TableGen record of this dialect.
+  const llvm::Record *getDef() const { return def; }
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 363e6016113b16..4fa3cdcbf85ce2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
-
+    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+    auto reqdWorkGroupSizeAttrHelper =
+        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+    auto flatWorkGroupSizeAttrHelper =
+        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
     // Manually rewrite known block size attributes so the LLVMIR translation
     // infrastructure can pick them up.
-    m.walk([ctx](LLVM::LLVMFuncOp op) {
+    m.walk([&](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  if (kernelAttrName.getName() == attr.getName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << kernelAttrName.getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
   return def->getValueAsBit("usePropertiesForAttributes");
 }
 
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+  return def->getValueAsDag("discardableAttrs"); // Defaults to `(ins)`.
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 0cbb3da79d151e..93eb456cdc2c4f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                  NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+    if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -99,12 +100,12 @@ class ROCDLDialectLLVMIRTranslationInterface
       if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
         llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
       }
-
     }
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+        attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -119,7 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+    if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
@@ -136,7 +137,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+    if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
+  let discardableAttrs = (ins
+     "mlir::IntegerAttr":$discardable_attr_key,
+     "SimpleAAttr":$other_discardable_attr_key
+  );
 
   let extraClassDeclaration = [{
     void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 4f2021083384fc..46e585a351a2b0 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
                           std::function<bool(const llvm::Record *)>>;
 } // namespace
 
+/// Appends a (name, C++ attribute type) pair for each argument of
+/// `discardableAttrDag`, failing fatally if any argument is unnamed.
+static void populateDiscardableAttributes(
+    const Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+  for (unsigned i : llvm::seq<unsigned>(0, discardableAttrDag->getNumArgs())) {
+    StringRef givenName = discardableAttrDag->getArgNameStr(i);
+    if (givenName.empty())
+      PrintFatalError(dialect.getDef()->getLoc(),
+                      "discardable attributes must be named");
+    discardableAttributes.emplace_back(
+        givenName.str(), discardableAttrDag->getArg(i)->getAsUnquotedString());
+  }
+}
+
 /// Given a set of records for a T, filter the ones that correspond to
 /// the given dialect.
 template <typename T>
@@ -180,6 +195,44 @@ static const char *const operationInterfaceFallbackDecl = R"(
                                       mlir::OperationName opName) override;
 )";
 
+/// The code block for the discardable attribute helper, instantiated once for
+/// each entry of the dialect's `discardableAttrs`. Placeholders:
+/// {0}: The attribute name in UpperCamelCase.
+/// {1}: The attribute name as written in ODS.
+/// {2}: The C++ attribute type.
+/// {3}: The attribute name in lowerCamelCase.
+/// {4}: The dialect name, used to prefix the attribute name.
+static const char *const discardableAttrHelperDecl = R"(
+    /// Helper to manage the discardable attribute `{1}`.
+    class {0}AttrHelper {{
+      ::mlir::StringAttr name;
+    public:
+      static constexpr ::llvm::StringLiteral getNameStr() {{ return "{4}.{1}"; }
+      ::mlir::StringAttr getName() const {{ return name; }
+
+      explicit {0}AttrHelper(::mlir::MLIRContext *ctx)
+        : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+      {2} getAttr(::mlir::Operation *op) const {{
+        return op->getAttrOfType<{2}>(name);
+      }
+      void setAttr(::mlir::Operation *op, {2} val) const {{
+        op->setAttr(name, val);
+      }
+      bool isAttrPresent(::mlir::Operation *op) const {{
+        return op->hasAttrOfType<{2}>(name);
+      }
+      void removeAttr(::mlir::Operation *op) const {{
+        assert(op->hasAttrOfType<{2}>(name));
+        op->removeAttr(name);
+      }
+    };
+    {0}AttrHelper get{0}AttrHelper() const {{ return {3}AttrName; }
+  private:
+    {0}AttrHelper {3}AttrName;
+  public:
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
   // Emit all nested namespaces.
@@ -215,6 +268,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
       os << regionResultAttrVerifierDecl;
     if (dialect.hasOperationInterfaceFallback())
       os << operationInterfaceFallbackDecl;
+
+    llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+    SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+    populateDiscardableAttributes(dialect, discardableAttrDag,
+                                  discardableAttributes);
+
+    for (const auto &attrPair : discardableAttributes) {
+      std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/true);
+      std::string camelName = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/false);
+      os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+                          attrPair.first, attrPair.second, camelName,
+                          dialect.getName());
+    }
+
     if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
       os << *extraDecl;
 
@@ -252,9 +321,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// {1}: Initialization code that is emitted in the ctor body before calling
 ///      initialize(), such as dependent dialect registration.
 /// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context)
-    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+    {3}
+     {{
   {1}
   initialize();
 }
@@ -268,7 +340,9 @@ static const char *const dialectDestructorStr = R"(
 
 )";
 
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+                           const llvm::RecordKeeper &recordKeeper,
+                           raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();
 
   // Emit the TypeID explicit specializations to have a single symbol def.
@@ -295,8 +369,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   // Emit the constructor and destructor.
   StringRef superClassName =
       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+  llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+  SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+  populateDiscardableAttributes(dialect, discardableAttrDag,
+                                discardableAttributes);
+  std::string discardableAttributesInit;
+  for (const auto &attrPair : discardableAttributes) {
+    std::string camelName = llvm::convertToCamelFromSnakeCase(
+        attrPair.first, /*capitalizeFirst=*/false);
+    llvm::raw_string_ostream os(discardableAttributesInit);
+    os << ", " << camelName << "AttrName(context)";
+  }
+
   os << llvm::formatv(dialectConstructorStr, cppClassName,
-                      dependentDialectRegistrations, superClassName);
+                      dependentDialectRegistrations, superClassName,
+                      discardableAttributesInit);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
 }
@@ -313,7 +401,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
-  emitDialectDef(*dialect, os);
+  emitDialectDef(*dialect, recordKeeper, os);
   return false;
 }
 



More information about the Mlir-commits mailing list