[Mlir-commits] [mlir] [mlir] NamedAttribute utility generator (PR #75118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 12:24:22 PST 2023


https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/75118

>From f49cb730203dcf97b5e4669e8926e046054d9b9f Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Mon, 11 Dec 2023 22:14:35 +0000
Subject: [PATCH] [mlir] NamedAttribute utility generator

All attributes in MLIR are named: inherent attributes have unscoped names,
and discardable attributes should be scoped with a dialect. Current usage is
ad hoc, and much of the codebase is sprinkled with constant strings used to
look up and set attributes, leading to potential bugs when names are not
updated in all usages.

This PR adds a TableGen'd utility wrapper for a NamedAttribute that manages
scoped/unscoped name lookup for consistent, typed access to the attribute on
an Operation.
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  26 ++--
 mlir/include/mlir/IR/AttrTypeBase.td          |  69 ++++++++++
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |   9 +-
 mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp   |   4 +-
 .../ROCDL/ROCDLToLLVMIRTranslation.cpp        |  10 +-
 mlir/test/IR/test-named-attrs.mlir            |  64 +++++++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |   7 +
 mlir/test/lib/Dialect/Test/TestAttributes.h   |   2 +
 mlir/test/lib/Dialect/Test/TestNamedAttrs.h   | 126 ++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |   1 +
 mlir/test/lib/IR/CMakeLists.txt               |   1 +
 mlir/test/lib/IR/TestNamedAttrs.cpp           |  95 +++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 13 files changed, 392 insertions(+), 24 deletions(-)
 create mode 100644 mlir/test/IR/test-named-attrs.mlir
 create mode 100644 mlir/test/lib/Dialect/Test/TestNamedAttrs.h
 create mode 100644 mlir/test/lib/IR/TestNamedAttrs.cpp

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..a40599e91e4b56 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -58,6 +48,29 @@ class ROCDL_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// ROCDL named attr definitions
+//===----------------------------------------------------------------------===//
+
+// Typed accessor for the discardable attribute `rocdl.<userName>`, whose value
+// is expected to be a `valueAttrType`.
+class ROCDL_NamedAttr<string name, string userName,
+                      string valueAttrType = "::mlir::Attribute"> :
+  NamedAttrDef<ROCDL_Dialect, name, userName, valueAttrType>;
+
+// `rocdl.kernel`: annotates kernel functions (verified to be on LLVM funcs).
+def ROCDL_KernelAttr : ROCDL_NamedAttr<"Kernel", "kernel", "::mlir::UnitAttr">;
+def ROCDL_ReqdWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"ReqdWorkGroupSize", "reqd_work_group_size",
+                    "::mlir::DenseI32ArrayAttr">;
+def ROCDL_FlatWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"FlatWorkGroupSize", "flat_work_group_size",
+                    "::mlir::StringAttr">;
+def ROCDL_MaxFlatWorkGroupSizeAttr :
+    ROCDL_NamedAttr<"MaxFlatWorkGroupSize", "max_flat_work_group_size",
+                    "::mlir::IntegerAttr">;
+def ROCDL_WavesPerEuAttr :
+    ROCDL_NamedAttr<"WavesPerEu", "waves_per_eu", "::mlir::IntegerAttr">;
 
 //===----------------------------------------------------------------------===//
 // ROCDL op definitions
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd41..ebf816fa844682 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -283,6 +283,75 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
                                  "::" # cppClassName # ">($_self)">;
 }
 
+// Define a StringAttr wrapper for the NamedAttribute `name`. Operations carry
+// it under the dialect-scoped name `<dialect>.<userName>`; the mnemonic is the
+// unscoped `userName` (the name used when the attribute is inherent). Static
+// helpers test/get/set/remove/look up/create the `valueAttrType` value on ops.
+class NamedAttrDef<Dialect dialect, string name, string userName,
+    string valueAttrType = "::mlir::Attribute">
+    : AttrDef<dialect, name, [], "::mlir::StringAttr"> {
+  let mnemonic = userName;
+
+  string scopedName = dialect.name # "." # mnemonic;
+  code getNameFunc = "static constexpr ::llvm::StringLiteral getName() { return \""
+      # scopedName # "\"; }\n";
+  code typedefValueAttr = "using ValueAttrType = " # valueAttrType # ";\n";
+
+  code namedAttrDecls = !strconcat(typedefValueAttr, getNameFunc, [{
+    // True if `attr` has the scoped name and a ValueAttrType value.
+    static bool is(const ::mlir::NamedAttribute &attr) {
+      return attr.getName() == getName() && ::llvm::isa<ValueAttrType>(attr.getValue());
+    }
+    static bool isInherent(const ::mlir::NamedAttribute &attr) {
+      return attr.getName() == getMnemonic();
+    }
+    static bool has(::mlir::Operation *op) {
+      return op->hasAttrOfType<ValueAttrType>(getName());
+    }
+    // Returns the scoped attribute name (not the value) as a StringAttr.
+    static ::mlir::StringAttr get(::mlir::MLIRContext *ctx) {
+      return ::mlir::StringAttr::get(ctx, getName());
+    }
+    // Returns the value on `op`; null if absent or not a ValueAttrType.
+    static ValueAttrType getValue(::mlir::Operation *op) {
+      return op->getAttrOfType<ValueAttrType>(getName());
+    }
+    // Value on `op` or an ancestor; the search ends at an IsolatedFromAbove op.
+    static ValueAttrType lookupValue(::mlir::Operation *op) {
+      if (auto attr = getValue(op))
+        return attr;
+      std::optional<::mlir::RegisteredOperationName> opInfo = op->getRegisteredInfo();
+      if (!opInfo || !opInfo->hasTrait<::mlir::OpTrait::IsIsolatedFromAbove>()) {
+        if (auto *par = op->getParentOp())
+          return lookupValue(par);
+      }
+      return ValueAttrType();
+    }
+    // Sets (or overwrites) the scoped attribute on `op`.
+    static void setValue(::mlir::Operation *op, ValueAttrType val) {
+      assert(op);
+      op->setAttr(getName(), val);
+    }
+    // Removes the scoped attribute from `op`.
+    static void removeValue(::mlir::Operation *op) {
+      assert(op);
+      op->removeAttr(getName());
+    }
+    // Creates a NamedAttribute with the scoped name and value `val`.
+    static ::mlir::NamedAttribute create(::mlir::Builder &b, ValueAttrType val);
+  }]);
+
+  code namedAttrDefs = [{
+    // Create (scoped) NamedAttribute
+    ::mlir::NamedAttribute $cppClass::create(::mlir::Builder &b, $cppClass::ValueAttrType val) {
+      return b.getNamedAttr($cppClass::getName(), val);
+    }
+  }];
+
+  let extraClassDeclaration = namedAttrDecls;
+  let extraClassDefinition = namedAttrDefs;
+}
+
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..81342e0679a7b2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -291,8 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
     m.walk([ctx](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        ROCDL::ReqdWorkGroupSizeAttr::setValue(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +300,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        ROCDL::FlatWorkGroupSizeAttr::setValue(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +353,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      /*kernelAttributeName=*/ROCDL::KernelAttr::get(&converter.getContext()));
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..078d026ac5222f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,11 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  // Match by name only (not via KernelAttr::is, which also requires a UnitAttr
+  // value) so that a mistyped `rocdl.kernel` is still placement-checked.
+  if (attr.getName() == ROCDL::KernelAttr::getName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << ROCDL::KernelAttr::getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab70280f6c818..0942bd2b5f3b38 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -83,7 +83,10 @@ class ROCDLDialectLLVMIRTranslationInterface
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    // Match attributes by name only, as before: the generated `is()` helpers
+    // also require the expected value type, so a mistyped attribute would
+    // bypass the handling (and any failure() paths) in these branches.
+    if (attribute.getName() == ROCDL::KernelAttr::getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -105,7 +108,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (attribute.getName() == ROCDL::MaxFlatWorkGroupSizeAttr::getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -120,8 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
-        attribute.getName()) {
+    if (attribute.getName() == ROCDL::FlatWorkGroupSizeAttr::getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -137,8 +139,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
-        attribute.getName()) {
+    if (attribute.getName() == ROCDL::ReqdWorkGroupSizeAttr::getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
diff --git a/mlir/test/IR/test-named-attrs.mlir b/mlir/test/IR/test-named-attrs.mlir
new file mode 100644
index 00000000000000..e9b4a97035b5b3
--- /dev/null
+++ b/mlir/test/IR/test-named-attrs.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -test-named-attrs -split-input-file --verify-diagnostics | FileCheck %s
+
+// The -test-named-attrs pass picks the check to run from a substring of each
+// function name: f_unit_attr, f_int_attr, f_lookup_attr or f_set_attr.
+
+func.func @f_unit_attr() attributes {test.test_unit} { // expected-remark {{found unit attr}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @f_unit_attr_fail() attributes {test.test_unit_fail} { // expected-error {{missing unit attr}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @f_int_attr() attributes {test.test_int = 42 : i32} { // expected-remark {{correct int value}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @f_int_attr_fail() attributes {test.test_int = 24 : i32} { // expected-error {{wrong int value}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @f_int_attr_fail2() attributes {test.test_int_fail = 42 : i64} { // expected-error {{missing int attr}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @f_lookup_attr() attributes {test.test_int = 42 : i64} { // expected-remark {{lookup found attr}}
+  %0:2 = "test.producer"() : () -> (i32, i32) // expected-remark {{lookup found attr}}
+  return // expected-remark {{lookup found attr}}
+}
+
+// -----
+
+func.func @f_lookup_attr2() { // expected-error {{lookup failed}}
+  "test.any_attr_of_i32_str"() {attr = 3 : i32, test.test_int = 24 : i32} : () -> () // expected-remark {{lookup found attr}}
+  return // expected-error {{lookup failed}}
+}
+
+// -----
+
+func.func @f_lookup_attr_fail() attributes {test.test_int_fail = 42 : i64} { // expected-error {{lookup failed}}
+  %0:2 = "test.producer"() : () -> (i32, i32) // expected-error {{lookup failed}}
+  return // expected-error {{lookup failed}}
+}
+
+// -----
+
+// CHECK: func.func @f_set_attr() attributes {test.test_int = 42 : i32}
+func.func @f_set_attr() { // expected-remark {{set int attr}}
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 945c54c04d47ce..42af8a5dde4c23 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -333,6 +333,17 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
 }
 
 
+// Named attributes `test.<userName>` built on NamedAttrDef, whose generated
+// helpers expect the attribute value to be a `valueAttrType`.
+class Test_NamedAttr<string name, string userName,
+                     string valueAttrType = "::mlir::Attribute"> :
+  NamedAttrDef<Test_Dialect, name, userName, valueAttrType>;
+
+def Test_NamedUnitAttr :
+    Test_NamedAttr<"TestNamedUnit", "test_unit", "::mlir::UnitAttr">;
+def Test_NamedIntAttr :
+    Test_NamedAttr<"TestNamedInt", "test_int", "::mlir::IntegerAttr">;
+
 
 
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index ef6eae51fdd628..393572e1a3c8ee 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -43,6 +43,8 @@ class CopyCount {
 llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                               const test::CopyCount &value);
 
+#include "mlir/IR/Operation.h" // Presumably for NamedAttrDef helpers.
+
 /// A handle used to reference external elements instances.
 using TestDialectResourceBlobHandle =
     mlir::DialectResourceBlobHandle<TestDialect>;
diff --git a/mlir/test/lib/Dialect/Test/TestNamedAttrs.h b/mlir/test/lib/Dialect/Test/TestNamedAttrs.h
new file mode 100644
index 00000000000000..24cd96fc665873
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestNamedAttrs.h
@@ -0,0 +1,21 @@
+//===- TestNamedAttrs.h - MLIR Test Dialect Named Attrs ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Convenience header for the TestDialect's NamedAttrDef-based attributes
+// (TestNamedUnitAttr, TestNamedIntAttr from TestAttrDefs.td). Their C++
+// classes are generated alongside the other test attributes, so this header
+// only forwards to TestAttributes.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTNAMEDATTRS_H
+#define MLIR_TESTNAMEDATTRS_H
+
+#include "TestAttributes.h"
+
+#endif // MLIR_TESTNAMEDATTRS_H
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 96f66c2ca06ecf..5629ed4ac91ca8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -15,6 +15,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/PatternBase.td"
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 69c63fd7e524b6..605a5b6ad7bedd 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTestIR
   TestInterfaces.cpp
   TestMatchers.cpp
   TestLazyLoading.cpp
+  TestNamedAttrs.cpp
   TestOpaqueLoc.cpp
   TestOperationEquals.cpp
   TestPrintDefUse.cpp
diff --git a/mlir/test/lib/IR/TestNamedAttrs.cpp b/mlir/test/lib/IR/TestNamedAttrs.cpp
new file mode 100644
index 00000000000000..f1167dead460e9
--- /dev/null
+++ b/mlir/test/lib/IR/TestNamedAttrs.cpp
@@ -0,0 +1,101 @@
+//===- TestNamedAttrs.cpp - Test passes for NamedAttrDef utilities --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "TestAttributes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+/// Exercises the NamedAttrDef helpers generated for the test dialect's named
+/// attributes (`test.test_unit`, `test.test_int`). The check to run is chosen
+/// by a substring of the function name; any other function is an error.
+struct TestNamedAttrsPass
+    : public PassWrapper<TestNamedAttrsPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNamedAttrsPass)
+
+  StringRef getArgument() const final { return "test-named-attrs"; }
+  StringRef getDescription() const final {
+    return "Test support for named attribute utilities";
+  }
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    StringRef funcName = func.getName();
+
+    // f_unit_attr*: the function must carry `test.test_unit` as a UnitAttr.
+    if (funcName.contains("f_unit_attr")) {
+      if (test::TestNamedUnitAttr::has(func)) {
+        func.emitRemark() << "found unit attr";
+      } else {
+        func.emitOpError() << "missing unit attr";
+        signalPassFailure();
+      }
+      return;
+    }
+
+    // f_int_attr*: the function must carry `test.test_int` as an IntegerAttr
+    // whose value is 42.
+    if (funcName.contains("f_int_attr")) {
+      if (!test::TestNamedIntAttr::has(func)) {
+        func.emitOpError() << "missing int attr";
+        signalPassFailure();
+      } else if (test::TestNamedIntAttr::getValue(func).getInt() == 42) {
+        func.emitRemark() << "correct int value";
+      } else {
+        func.emitOpError() << "wrong int value";
+        signalPassFailure();
+      }
+      return;
+    }
+
+    // f_lookup_attr*: every op in the function (including the function
+    // itself) must find `test.test_int` via lookupValue, i.e. on itself or on
+    // an enclosing op.
+    if (funcName.contains("f_lookup_attr")) {
+      func.walk([&](Operation *op) {
+        if (test::TestNamedIntAttr::lookupValue(op)) {
+          op->emitRemark() << "lookup found attr";
+        } else {
+          op->emitOpError() << "lookup failed";
+          signalPassFailure();
+        }
+      });
+      return;
+    }
+
+    // f_set_attr*: `test.test_int` must not be set yet; set it to 42 : i32.
+    if (funcName.contains("f_set_attr")) {
+      if (!test::TestNamedIntAttr::has(func)) {
+        auto intTy = IntegerType::get(func.getContext(), 32);
+        test::TestNamedIntAttr::setValue(func, IntegerAttr::get(intTy, 42));
+        func.emitRemark() << "set int attr";
+      } else {
+        func.emitOpError() << "attr already set";
+        signalPassFailure();
+      }
+      return;
+    }
+
+    // Unknown key.
+    func.emitOpError() << "unexpected function name";
+    signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+
+void registerTestNamedAttrsPass() {
+  PassRegistration<TestNamedAttrsPass>();
+}
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b4850560..232aef81cfaa84 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestNamedAttrsPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestNamedAttrsPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();



More information about the Mlir-commits mailing list