[Mlir-commits] [mlir] [mlir] Implement OpAsmAttrInterface for some Builtin Attributes (PR #128191)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 21 08:02:12 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Hongren Zheng (ZenithalHourlyRate)

<details>
<summary>Changes</summary>

After `OpAsmAttrInterface` was introduced for alias generation in #124721, a natural way to exercise it is to migrate MLIR's existing alias generation mechanism, i.e. `OpAsmDialectInterface`, to the new interface.

The existing `BuiltinOpAsmDialectInterface` generates aliases for `AffineMapAttr` and `IntegerSetAttr`; these attributes can be migrated to `OpAsmAttrInterface` instead.

However, the tricky part is that `OpAsmAttrInterface` lives in `OpImplementation.h`. If `BuiltinAttributes.h` included that header, it would create a cyclic include.

Note that only builtin attributes and types face this issue, because out-of-tree users can simply include `OpImplementation.h` (see the downstream example https://github.com/google/heir/pull/1437).

The dependency arises because `OpAsmAttrInterface` uses `OpAsmDialectInterface::AliasResult`.

There are two possible solutions:

1. Move `OpAsmDialectInterface` into its own header, `OpAsmDialectInterface.h`, which the headers for the other `OpAsm{Attr,Type,Op}Interface` interfaces would include.

    * Placing `OpAsmDialectInterface` in `OpImplementation.h`, as is done today, does not match what the header's name suggests.
  
2. Put `AliasResult` in _some_ header that all of the interfaces can include safely.

    * Note that `AliasAnalysis.h` already defines an `AliasResult` for a different purpose, so this class would need to be renamed.

This PR currently takes the first approach to demonstrate the dependency problem. Either solution would work, and where to put all of these interfaces needs some discussion.

---
Full diff: https://github.com/llvm/llvm-project/pull/128191.diff


7 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinAttributeInterfaces.h (+2) 
- (modified) mlir/include/mlir/IR/BuiltinAttributes.h (+1) 
- (modified) mlir/include/mlir/IR/BuiltinAttributes.td (+26-3) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.h (+2) 
- (added) mlir/include/mlir/IR/OpAsmDialectInterface.h (+177) 
- (modified) mlir/include/mlir/IR/OpImplementation.h (+1-160) 
- (modified) mlir/lib/IR/BuiltinDialect.cpp (-8) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index c4a42020d1389..982d35460ba41 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/OpAsmDialectInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/Support/raw_ostream.h"
 #include <complex>
@@ -279,6 +280,7 @@ verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
+#include "mlir/IR/OpAsmAttrInterface.h.inc"
 
 //===----------------------------------------------------------------------===//
 // ElementsAttr
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 901df3a25a46f..6155d0c65c67d 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,6 +10,7 @@
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/OpAsmDialectInterface.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include <complex>
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 06f5e172a9909..2c86f92686e23 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -37,7 +37,8 @@ class Builtin_Attr<string name, string attrMnemonic, list<Trait> traits = [],
 //===----------------------------------------------------------------------===//
 
 def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", "affine_map", [
-    MemRefLayoutAttrInterface
+    MemRefLayoutAttrInterface,
+    OpAsmAttrInterface
   ]> {
   let summary = "An Attribute containing an AffineMap object";
   let description = [{
@@ -63,6 +64,16 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", "affine_map", [
   let extraClassDeclaration = [{
     using ValueType = AffineMap;
     AffineMap getAffineMap() const { return getValue(); }
+
+    //===------------------------------------------------------------------===//
+    // OpAsmAttrInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Get a name to use when generating an alias for this attribute.
+    ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const {
+      os << "map";
+      return ::mlir::OpAsmDialectInterface::AliasResult::OverridableAlias;
+    }
   }];
   let skipDefaultBuilders = 1;
 }
@@ -755,7 +766,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
 // IntegerSetAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet", "integer_set"> {
+def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet", "integer_set", [OpAsmAttrInterface]> {
   let summary = "An Attribute containing an IntegerSet object";
   let description = [{
     Syntax:
@@ -776,7 +787,19 @@ def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet", "integer_set"> {
       return $_get(value.getContext(), value);
     }]>
   ];
-  let extraClassDeclaration = "using ValueType = IntegerSet;";
+  let extraClassDeclaration = [{
+    using ValueType = IntegerSet;
+
+    //===------------------------------------------------------------------===//
+    // OpAsmAttrInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Get a name to use when generating an alias for this attribute.
+    ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const {
+      os << "set";
+      return ::mlir::OpAsmDialectInterface::AliasResult::OverridableAlias;
+    }
+  }];
   let skipDefaultBuilders = 1;
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index e8011b5488dc9..5851c624635bf 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_BUILTINTYPEINTERFACES_H
 #define MLIR_IR_BUILTINTYPEINTERFACES_H
 
+#include "mlir/IR/OpAsmDialectInterface.h"
 #include "mlir/IR/Types.h"
 
 namespace llvm {
@@ -21,5 +22,6 @@ class MLIRContext;
 } // namespace mlir
 
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
+#include "mlir/IR/OpAsmTypeInterface.h.inc"
 
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/OpAsmDialectInterface.h b/mlir/include/mlir/IR/OpAsmDialectInterface.h
new file mode 100644
index 0000000000000..9965d858daae4
--- /dev/null
+++ b/mlir/include/mlir/IR/OpAsmDialectInterface.h
@@ -0,0 +1,177 @@
+//===- OpAsmDialectInterface.h - OpAsm Dialect Interface --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OPASMDIALECTINTERFACE_H
+#define MLIR_IR_OPASMDIALECTINTERFACE_H
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+class AsmParsedResourceEntry;
+class AsmResourceBuilder;
+
+//===----------------------------------------------------------------------===//
+// AsmDialectResourceHandle
+//===----------------------------------------------------------------------===//
+
+/// This class represents an opaque handle to a dialect resource entry.
+class AsmDialectResourceHandle {
+public:
+  AsmDialectResourceHandle() = default;
+  AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
+      : resource(resource), opaqueID(resourceID), dialect(dialect) {}
+  bool operator==(const AsmDialectResourceHandle &other) const {
+    return resource == other.resource;
+  }
+
+  /// Return an opaque pointer to the referenced resource.
+  void *getResource() const { return resource; }
+
+  /// Return the type ID of the resource.
+  TypeID getTypeID() const { return opaqueID; }
+
+  /// Return the dialect that owns the resource.
+  Dialect *getDialect() const { return dialect; }
+
+private:
+  /// The opaque handle to the dialect resource.
+  void *resource = nullptr;
+  /// The type of the resource referenced.
+  TypeID opaqueID;
+  /// The dialect owning the given resource.
+  Dialect *dialect;
+};
+
+/// This class represents a CRTP base class for dialect resource handles. It
+/// abstracts away various utilities necessary for defined derived resource
+/// handles.
+template <typename DerivedT, typename ResourceT, typename DialectT>
+class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
+public:
+  using Dialect = DialectT;
+
+  /// Construct a handle from a pointer to the resource. The given pointer
+  /// should be guaranteed to live beyond the life of this handle.
+  AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
+      : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
+  AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
+      : AsmDialectResourceHandle(handle) {
+    assert(handle.getTypeID() == TypeID::get<DerivedT>());
+  }
+
+  /// Return the resource referenced by this handle.
+  ResourceT *getResource() {
+    return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
+  }
+  const ResourceT *getResource() const {
+    return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
+  }
+
+  /// Return the dialect that owns the resource.
+  DialectT *getDialect() const {
+    return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
+  }
+
+  /// Support llvm style casting.
+  static bool classof(const AsmDialectResourceHandle *handle) {
+    return handle->getTypeID() == TypeID::get<DerivedT>();
+  }
+};
+
+inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
+  return llvm::hash_value(param.getResource());
+}
+
+//===--------------------------------------------------------------------===//
+// Dialect OpAsm interface.
+//===--------------------------------------------------------------------===//
+
+/// A functor used to set the name of the result. See 'getAsmResultNames' below
+/// for more details.
+using OpAsmSetNameFn = function_ref<void(StringRef)>;
+
+/// A functor used to set the name of the start of a result group of an
+/// operation. See 'getAsmResultNames' below for more details.
+using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
+
+/// A functor used to set the name of blocks in regions directly nested under
+/// an operation.
+using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
+
+class OpAsmDialectInterface
+    : public DialectInterface::Base<OpAsmDialectInterface> {
+public:
+  OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+  //===------------------------------------------------------------------===//
+  // Aliases
+  //===------------------------------------------------------------------===//
+
+  /// Holds the result of `getAlias` hook call.
+  enum class AliasResult {
+    /// The object (type or attribute) is not supported by the hook
+    /// and an alias was not provided.
+    NoAlias,
+    /// An alias was provided, but it might be overriden by other hook.
+    OverridableAlias,
+    /// An alias was provided and it should be used
+    /// (no other hooks will be checked).
+    FinalAlias
+  };
+
+  /// Hooks for getting an alias identifier alias for a given symbol, that is
+  /// not necessarily a part of this dialect. The identifier is used in place of
+  /// the symbol when printing textual IR. These aliases must not contain `.` or
+  /// end with a numeric digit([0-9]+).
+  virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
+    return AliasResult::NoAlias;
+  }
+  virtual AliasResult getAlias(Type type, raw_ostream &os) const {
+    return AliasResult::NoAlias;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Resources
+  //===--------------------------------------------------------------------===//
+
+  /// Declare a resource with the given key, returning a handle to use for any
+  /// references of this resource key within the IR during parsing. The result
+  /// of `getResourceKey` on the returned handle is permitted to be different
+  /// than `key`.
+  virtual FailureOr<AsmDialectResourceHandle>
+  declareResource(StringRef key) const {
+    return failure();
+  }
+
+  /// Return a key to use for the given resource. This key should uniquely
+  /// identify this resource within the dialect.
+  virtual std::string
+  getResourceKey(const AsmDialectResourceHandle &handle) const {
+    llvm_unreachable(
+        "Dialect must implement `getResourceKey` when defining resources");
+  }
+
+  /// Hook for parsing resource entries. Returns failure if the entry was not
+  /// valid, or could otherwise not be processed correctly. Any necessary errors
+  /// can be emitted via the provided entry.
+  virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
+
+  /// Hook for building resources to use during printing. The given `op` may be
+  /// inspected to help determine what information to include.
+  /// `referencedResources` contains all of the resources detected when printing
+  /// 'op'.
+  virtual void
+  buildResources(Operation *op,
+                 const SetVector<AsmDialectResourceHandle> &referencedResources,
+                 AsmResourceBuilder &builder) const {}
+};
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d80bff2d78217..1e7ab6ac4575e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -15,88 +15,15 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/OpAsmDialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
 #include <optional>
 
 namespace mlir {
-class AsmParsedResourceEntry;
-class AsmResourceBuilder;
 class Builder;
 
-//===----------------------------------------------------------------------===//
-// AsmDialectResourceHandle
-//===----------------------------------------------------------------------===//
-
-/// This class represents an opaque handle to a dialect resource entry.
-class AsmDialectResourceHandle {
-public:
-  AsmDialectResourceHandle() = default;
-  AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
-      : resource(resource), opaqueID(resourceID), dialect(dialect) {}
-  bool operator==(const AsmDialectResourceHandle &other) const {
-    return resource == other.resource;
-  }
-
-  /// Return an opaque pointer to the referenced resource.
-  void *getResource() const { return resource; }
-
-  /// Return the type ID of the resource.
-  TypeID getTypeID() const { return opaqueID; }
-
-  /// Return the dialect that owns the resource.
-  Dialect *getDialect() const { return dialect; }
-
-private:
-  /// The opaque handle to the dialect resource.
-  void *resource = nullptr;
-  /// The type of the resource referenced.
-  TypeID opaqueID;
-  /// The dialect owning the given resource.
-  Dialect *dialect;
-};
-
-/// This class represents a CRTP base class for dialect resource handles. It
-/// abstracts away various utilities necessary for defined derived resource
-/// handles.
-template <typename DerivedT, typename ResourceT, typename DialectT>
-class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
-public:
-  using Dialect = DialectT;
-
-  /// Construct a handle from a pointer to the resource. The given pointer
-  /// should be guaranteed to live beyond the life of this handle.
-  AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
-      : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
-  AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
-      : AsmDialectResourceHandle(handle) {
-    assert(handle.getTypeID() == TypeID::get<DerivedT>());
-  }
-
-  /// Return the resource referenced by this handle.
-  ResourceT *getResource() {
-    return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
-  }
-  const ResourceT *getResource() const {
-    return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
-  }
-
-  /// Return the dialect that owns the resource.
-  DialectT *getDialect() const {
-    return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
-  }
-
-  /// Support llvm style casting.
-  static bool classof(const AsmDialectResourceHandle *handle) {
-    return handle->getTypeID() == TypeID::get<DerivedT>();
-  }
-};
-
-inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
-  return llvm::hash_value(param.getResource());
-}
-
 //===----------------------------------------------------------------------===//
 // AsmPrinter
 //===----------------------------------------------------------------------===//
@@ -1726,90 +1653,6 @@ class OpAsmParser : public AsmParser {
                               SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
 };
 
-//===--------------------------------------------------------------------===//
-// Dialect OpAsm interface.
-//===--------------------------------------------------------------------===//
-
-/// A functor used to set the name of the result. See 'getAsmResultNames' below
-/// for more details.
-using OpAsmSetNameFn = function_ref<void(StringRef)>;
-
-/// A functor used to set the name of the start of a result group of an
-/// operation. See 'getAsmResultNames' below for more details.
-using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
-
-/// A functor used to set the name of blocks in regions directly nested under
-/// an operation.
-using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
-
-class OpAsmDialectInterface
-    : public DialectInterface::Base<OpAsmDialectInterface> {
-public:
-  OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
-
-  //===------------------------------------------------------------------===//
-  // Aliases
-  //===------------------------------------------------------------------===//
-
-  /// Holds the result of `getAlias` hook call.
-  enum class AliasResult {
-    /// The object (type or attribute) is not supported by the hook
-    /// and an alias was not provided.
-    NoAlias,
-    /// An alias was provided, but it might be overriden by other hook.
-    OverridableAlias,
-    /// An alias was provided and it should be used
-    /// (no other hooks will be checked).
-    FinalAlias
-  };
-
-  /// Hooks for getting an alias identifier alias for a given symbol, that is
-  /// not necessarily a part of this dialect. The identifier is used in place of
-  /// the symbol when printing textual IR. These aliases must not contain `.` or
-  /// end with a numeric digit([0-9]+).
-  virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
-    return AliasResult::NoAlias;
-  }
-  virtual AliasResult getAlias(Type type, raw_ostream &os) const {
-    return AliasResult::NoAlias;
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Resources
-  //===--------------------------------------------------------------------===//
-
-  /// Declare a resource with the given key, returning a handle to use for any
-  /// references of this resource key within the IR during parsing. The result
-  /// of `getResourceKey` on the returned handle is permitted to be different
-  /// than `key`.
-  virtual FailureOr<AsmDialectResourceHandle>
-  declareResource(StringRef key) const {
-    return failure();
-  }
-
-  /// Return a key to use for the given resource. This key should uniquely
-  /// identify this resource within the dialect.
-  virtual std::string
-  getResourceKey(const AsmDialectResourceHandle &handle) const {
-    llvm_unreachable(
-        "Dialect must implement `getResourceKey` when defining resources");
-  }
-
-  /// Hook for parsing resource entries. Returns failure if the entry was not
-  /// valid, or could otherwise not be processed correctly. Any necessary errors
-  /// can be emitted via the provided entry.
-  virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
-
-  /// Hook for building resources to use during printing. The given `op` may be
-  /// inspected to help determine what information to include.
-  /// `referencedResources` contains all of the resources detected when printing
-  /// 'op'.
-  virtual void
-  buildResources(Operation *op,
-                 const SetVector<AsmDialectResourceHandle> &referencedResources,
-                 AsmResourceBuilder &builder) const {}
-};
-
 //===--------------------------------------------------------------------===//
 // Custom printers and parsers.
 //===--------------------------------------------------------------------===//
@@ -1827,9 +1670,7 @@ ParseResult parseDimensionList(OpAsmParser &parser,
 //===--------------------------------------------------------------------===//
 
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
-#include "mlir/IR/OpAsmAttrInterface.h.inc"
 #include "mlir/IR/OpAsmOpInterface.h.inc"
-#include "mlir/IR/OpAsmTypeInterface.h.inc"
 
 namespace llvm {
 template <>
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 99796c5f1c371..2867d9b4ef68a 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -48,14 +48,6 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
       : OpAsmDialectInterface(dialect), blobManager(mgr) {}
 
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
-    if (llvm::isa<AffineMapAttr>(attr)) {
-      os << "map";
-      return AliasResult::OverridableAlias;
-    }
-    if (llvm::isa<IntegerSetAttr>(attr)) {
-      os << "set";
-      return AliasResult::OverridableAlias;
-    }
     if (llvm::isa<LocationAttr>(attr)) {
       os << "loc";
       return AliasResult::OverridableAlias;

``````````

</details>


https://github.com/llvm/llvm-project/pull/128191


More information about the Mlir-commits mailing list