[Mlir-commits] [mlir] [mlir] NamedAttribute utility generator (PR #75118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 15:47:11 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: SJW (sjw36)

<details>
<summary>Changes</summary>

All attributes in MLIR are named: inherent attributes have unscoped names, while discardable (non-inherent) attributes should be scoped with a dialect prefix. Current usage is ad hoc, and much of the codebase is sprinkled with constant strings used to look up and set attributes, which can lead to bugs when a name is not updated everywhere it is used.

This PR adds a TableGen-generated utility wrapper for a NamedAttribute. When the attribute is accessed on an Operation, the wrapper resolves the correct name: the unscoped name if the attribute is inherent to that operation, and the dialect-scoped name otherwise.

---
Full diff: https://github.com/llvm/llvm-project/pull/75118.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+16-10) 
- (modified) mlir/include/mlir/IR/AttrTypeBase.td (+72) 
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+3-6) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+2-2) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp (+4-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f29..a40599e91e4b5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -58,6 +48,22 @@ class ROCDL_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// ROCDL named attr definitions
+//===----------------------------------------------------------------------===//
+
+class ROCDL_NamedAttr<string name, string userName, string baseAttrType = "::mlir::Attribute"> :
+  NamedAttrDef<ROCDL_Dialect, name, userName, baseAttrType>;
+
+def ROCDL_KernelAttr : ROCDL_NamedAttr<"Kernel", "kernel", "::mlir::UnitAttr">;
+def ROCDL_ReqdWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"ReqdWorkGroupSize", "reqd_work_group_size", "::mlir::DenseI32ArrayAttr">;
+def ROCDL_FlatWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"FlatWorkGroupSize", "flat_work_group_size", "::mlir::StringAttr">;
+def ROCDL_MaxFlatWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"MaxFlatWorkGroupSize", "max_flat_work_group_size", "::mlir::IntegerAttr">;
+def ROCDL_WavesPerEuAttr :
+    ROCDL_NamedAttr<"WavesPerEu", "waves_per_eu", "::mlir::IntegerAttr">;
 
 //===----------------------------------------------------------------------===//
 // ROCDL op definitions
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd4..eeabbf7e06471 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -283,6 +283,78 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
                                  "::" # cppClassName # ">($_self)">;
 }
 
+// Define a StringAttr wrapper for the NamedAttribute `name`
+// - `name` is dialect-scoped when not-inherent.
+// - Utilities to is/has/get/set/lookup/create typed Attr on an Operation
+//   including typed `value` attribute
+class NamedAttrDef<Dialect dialect, string name, string userName,
+    string valueAttrType = "::mlir::Attribute">
+    : AttrDef<dialect, name, [], "::mlir::StringAttr"> {
+  let mnemonic = userName;
+
+  string scopedName = dialect.name # "." # mnemonic;
+  code typedefValueAttr = "typedef " # valueAttrType # " ValueAttrType;\n";
+  code getNameFunc = "static constexpr llvm::StringLiteral getScopedName() { return \""
+      # scopedName # "\"; }\n";
+
+  code namedAttrFuncs = !strconcat(typedefValueAttr, getNameFunc, [{
+    // Get name based on inherentness
+    static llvm::StringLiteral getName(Operation *op = nullptr) {
+      if (op && op->getPropertiesStorageSize()) {
+       auto mnemonic = getMnemonic();
+       if (op->getInherentAttr(mnemonic))
+         return mnemonic;
+      }
+      return getScopedName();
+    }
+    // Is or Has
+    static bool is(::mlir::NamedAttribute attr) {
+      return attr.getName() == getScopedName();
+    }
+    static bool isInherent(::mlir::NamedAttribute attr) {
+      return attr.getName() == getMnemonic();
+    }
+    static bool has(::mlir::Operation *op) {
+      return op->hasAttr(getName(op));
+    }
+    // Get Name
+    static ::mlir::StringAttr get(::mlir::MLIRContext *ctx, ::mlir::Operation *op = nullptr) {
+      return ::mlir::StringAttr::get(ctx, getName(op));
+    }
+    // Get Value
+    static ValueAttrType getValue(::mlir::Operation *op) {
+      return op->getAttrOfType<ValueAttrType>(getName(op));
+    }
+    // Scoped lookup for inheritance
+    static ValueAttrType lookupValue(::mlir::Operation *op) {
+      if (auto attr = getValue(op))
+        return attr;
+      std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
+      if (!opInfo || !opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+        if (auto *par = op->getParentOp())
+          return lookupValue(par);
+      }
+      return ValueAttrType();
+    }
+    // Set Value on Op
+    static void setValue(::mlir::Operation *op, ValueAttrType val) {
+      assert(op);
+      op->setAttr(getName(op), val);
+    }
+    // Remove Value from Op
+    static void removeValue(::mlir::Operation *op) {
+      assert(op);
+      op->removeAttr(getName(op));
+    }
+    // Create (scoped) NamedAttribute
+    static ::mlir::NamedAttribute create(::mlir::Builder &b, ValueAttrType val) {
+      return b.getNamedAttr(getScopedName(), val);
+    }
+  }]);
+
+  let extraClassDeclaration = namedAttrFuncs;
+}
+
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12..81342e0679a7b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -291,8 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
     m.walk([ctx](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        ROCDL::ReqdWorkGroupSizeAttr::setValue(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +300,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        ROCDL::FlatWorkGroupSizeAttr::setValue(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +353,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      ROCDL::KernelAttr::get(&converter.getContext()));
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc01..078d026ac5222 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  if (ROCDL::KernelAttr::is(attr)) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << ROCDL::KernelAttr::getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab70280f6c81..0942bd2b5f3b3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -83,7 +83,7 @@ class ROCDLDialectLLVMIRTranslationInterface
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    if (ROCDL::KernelAttr::is(attribute)) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -105,7 +105,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (ROCDL::MaxFlatWorkGroupSizeAttr::is(attribute)) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -120,8 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
-        attribute.getName()) {
+    if (ROCDL::FlatWorkGroupSizeAttr::is(attribute)) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -137,8 +136,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
-        attribute.getName()) {
+    if (ROCDL::ReqdWorkGroupSizeAttr::is(attribute)) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/75118


More information about the Mlir-commits mailing list