[Mlir-commits] [mlir] [MLIR] Add ODS support for generating helpers for dialect (discardable) attributes (PR #77024)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 4 15:33:19 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-llvm

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

Work in progress (documentation is still missing).

See https://github.com/llvm/llvm-project/pull/75118 for context

---
Full diff: https://github.com/llvm/llvm-project/pull/77024.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+8-10) 
- (modified) mlir/include/mlir/IR/DialectBase.td (+2) 
- (modified) mlir/include/mlir/TableGen/Dialect.h (+6) 
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-8) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+2-2) 
- (modified) mlir/lib/TableGen/Dialect.cpp (+4) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp (+6-4) 
- (modified) mlir/test/lib/Dialect/Test/TestDialect.td (+4) 
- (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+85-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
     static constexpr unsigned kPrivateMemoryAddressSpace = 5;
   }];
 
+  let discardableAttrs = (ins
+     "::mlir::UnitAttr":$kernel,
+     "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+     "::mlir::StringAttr":$flat_work_group_size,
+     "::mlir::IntegerAttr":$max_flat_work_group_size,
+     "::mlir::IntegerAttr":$waves_per_eu
+  );
+
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
   // pattern or interfaces.
   list<string> dependentDialects = [];
 
+  dag discardableAttrs = (ins);
+
   // The C++ namespace that ops of this dialect should be placed into.
   //
   // By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
 #define MLIR_TABLEGEN_DIALECT_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
 #include <string>
 #include <vector>
 
@@ -90,6 +92,10 @@ class Dialect {
   /// dialect.
   bool usePropertiesForAttributes() const;
 
+  llvm::DagInit *getDiscardableAttributes() const;
+
+  const llvm::Record *getDef() const { return def; }
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
-
+    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+    auto reqdWorkGroupSizeAttrHelper =
+        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+    auto flatWorkGroupSizeAttrHelper =
+        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
     // Manually rewrite known block size attributes so the LLVMIR translation
     // infrastructure can pick them up.
-    m.walk([ctx](LLVM::LLVMFuncOp op) {
+    m.walk([&](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  if (kernelAttrName.getName() == attr.getName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << kernelAttrName.getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
   return def->getValueAsBit("usePropertiesForAttributes");
 }
 
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+  return def->getValueAsDag("discardableAttrs");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                  NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+    if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+        attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+    if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+    if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
+  let discardableAttrs = (ins
+     "mlir::IntegerAttr":$discardable_attr_key,
+     "SimpleAAttr":$other_discardable_attr_key
+  );
 
   let extraClassDeclaration = [{
     void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
                           std::function<bool(const llvm::Record *)>>;
 } // namespace
 
+static void populateDiscardableAttributes(
+    Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+  for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+    llvm::Init *arg = discardableAttrDag->getArg(i);
+
+    StringRef givenName = discardableAttrDag->getArgNameStr(i);
+    if (givenName.empty())
+      PrintFatalError(dialect.getDef()->getLoc(),
+                      "discardable attributes must be named");
+    discardableAttributes.push_back(
+        {givenName.str(), arg->getAsUnquotedString()});
+  }
+}
+
 /// Given a set of records for a T, filter the ones that correspond to
 /// the given dialect.
 template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
                                       mlir::OperationName opName) override;
 )";
 
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+    /// Helper to manage the discardable attribute `{1}`.
+    class {0}AttrHelper {{
+      mlir::StringAttr name;
+    public:
+      static constexpr llvm::StringLiteral getNameStr() {{
+        return "{4}.{1}";
+      }
+      constexpr mlir::StringAttr getName() {{
+        return name;
+      }
+
+      {0}AttrHelper(mlir::MLIRContext *ctx)
+        : name(mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+     {2} getAttr(::mlir::Operation *op) {{
+       return op->getAttrOfType<{2}>(getName());
+     }
+     void setAttr(::mlir::Operation *op, {2} val) {{
+       op->setAttr(getName(), val);
+     }
+   };
+   {0}AttrHelper get{0}AttrHelper() {
+     return {3}AttrName;
+   }
+ private:
+   {0}AttrHelper {3}AttrName;
+ public:
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
   // Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
       os << regionResultAttrVerifierDecl;
     if (dialect.hasOperationInterfaceFallback())
       os << operationInterfaceFallbackDecl;
+
+    llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+    SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+    populateDiscardableAttributes(dialect, discardableAttrDag,
+                                  discardableAttributes);
+
+    for (const auto &attrPair : discardableAttributes) {
+      std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/true);
+      std::string camelName = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/false);
+      os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+                          attrPair.first, attrPair.second, camelName,
+                          dialect.getName());
+    }
+
     if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
       os << *extraDecl;
 
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// {1}: initialization code that is emitted in the ctor body before calling
 ///      initialize().
 /// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context)
-    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+    {3}
+     {{
   {1}
   initialize();
 }
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
 
 )";
 
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+                           const llvm::RecordKeeper &recordKeeper,
+                           raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();
 
   // Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   // Emit the constructor and destructor.
   StringRef superClassName =
       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+  llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+  SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+  populateDiscardableAttributes(dialect, discardableAttrDag,
+                                discardableAttributes);
+  std::string discardableAttributesInit;
+  for (const auto &attrPair : discardableAttributes) {
+    std::string camelName = llvm::convertToCamelFromSnakeCase(
+        attrPair.first, /*capitalizeFirst=*/false);
+    llvm::raw_string_ostream os(discardableAttributesInit);
+    os << ", " << camelName << "AttrName(context)";
+  }
+
   os << llvm::formatv(dialectConstructorStr, cppClassName,
-                      dependentDialectRegistrations, superClassName);
+                      dependentDialectRegistrations, superClassName,
+                      discardableAttributesInit);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
 }
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
-  emitDialectDef(*dialect, os);
+  emitDialectDef(*dialect, recordKeeper, os);
   return false;
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/77024


More information about the Mlir-commits mailing list