[Mlir-commits] [mlir] 802bf02 - [mlir] Allows to query traits from types and attributes

Mehdi Amini llvmlistbot at llvm.org
Sun Sep 12 23:29:46 PDT 2021


Author: Mathieu Fehr
Date: 2021-09-13T06:26:45Z
New Revision: 802bf02a738e091d5bf22c03e83204a38d2c7950

URL: https://github.com/llvm/llvm-project/commit/802bf02a738e091d5bf22c03e83204a38d2c7950
DIFF: https://github.com/llvm/llvm-project/commit/802bf02a738e091d5bf22c03e83204a38d2c7950.diff

LOG: [mlir] Allows to query traits from types and attributes

Types and attributes now have a `hasTrait` function that allows users to check
whether a type or attribute defines a trait.
Also, AbstractType and AbstractAttribute now have a `hasTraitFn` field to carry
the implementation of the `hasTrait` function of the concrete type or attribute.
This patch also adds the remaining TableGen classes for defining type and
attribute traits.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D105202

Added: 
    mlir/test/lib/Dialect/Test/TestTraits.h

Modified: 
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/test/IR/traits.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index c84be620b8f0f..d18f0ab8aad70 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -30,14 +30,18 @@ class Type;
 /// a registered Attribute.
 class AbstractAttribute {
 public:
+  using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
   /// Look up the specified abstract attribute in the MLIRContext and return a
   /// reference to it.
   static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context);
 
   /// This method is used by Dialect objects when they register the list of
   /// attributes they contain.
-  template <typename T> static AbstractAttribute get(Dialect &dialect) {
-    return AbstractAttribute(dialect, T::getInterfaceMap(), T::getTypeID());
+  template <typename T>
+  static AbstractAttribute get(Dialect &dialect) {
+    return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+                             T::getTypeID());
   }
 
   /// Return the dialect this attribute was registered to.
@@ -46,7 +50,8 @@ class AbstractAttribute {
   /// Returns an instance of the concept object for the given interface if it
   /// was registered to this attribute, null otherwise. This should not be used
   /// directly.
-  template <typename T> typename T::Concept *getInterface() const {
+  template <typename T>
+  typename T::Concept *getInterface() const {
     return interfaceMap.lookup<T>();
   }
 
@@ -56,14 +61,23 @@ class AbstractAttribute {
     return interfaceMap.contains(interfaceID);
   }
 
+  /// Returns true if the attribute has a particular trait.
+  template <template <typename T> class Trait>
+  bool hasTrait() const {
+    return hasTraitFn(TypeID::get<Trait>());
+  }
+
+  /// Returns true if the attribute has the trait identified by `traitID`.
+  bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
   /// Return the unique identifier representing the concrete attribute class.
   TypeID getTypeID() const { return typeID; }
 
 private:
   AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-                    TypeID typeID)
+                    HasTraitFn &&hasTrait, TypeID typeID)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
-        typeID(typeID) {}
+        hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -81,6 +95,9 @@ class AbstractAttribute {
   /// This is a collection of the interfaces registered to this attribute.
   detail::InterfaceMap interfaceMap;
 
+  /// Function to check if the attribute has a particular trait.
+  HasTraitFn hasTraitFn;
+
   /// The unique identifier of the derived Attribute class.
   const TypeID typeID;
 };

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 2bc294b3609f6..69e48d7a72561 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -84,6 +84,12 @@ class Attribute {
 
   friend ::llvm::hash_code hash_value(Attribute arg);
 
+  /// Returns true if the attribute was registered with a particular trait.
+  template <template <typename T> class Trait>
+  bool hasTrait() const {
+    return getAbstractAttribute().hasTrait<Trait>();
+  }
+
   /// Return the abstract descriptor for this attribute.
   const AbstractTy &getAbstractAttribute() const {
     return impl->getAbstractAttribute();

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 4c8b5e5014e0a..cb97424281012 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1809,6 +1809,40 @@ class PredTrait<string descr, Pred pred> : Trait {
   Pred predicate = pred;
 }
 
+//===----------------------------------------------------------------------===//
+// TypeTrait definitions
+//===----------------------------------------------------------------------===//
+
+// TypeTrait marks a trait that can be attached to a type.
+// TODO: Remove this class in favor of using Trait.
+class TypeTrait;
+
+// These classes are used to define type-specific traits.
+class NativeTypeTrait<string name> : NativeTrait<name, "Type">, TypeTrait;
+class ParamNativeTypeTrait<string prop, string params>
+    : ParamNativeTrait<prop, params, "Type">, TypeTrait;
+class GenInternalTypeTrait<string prop>
+    : GenInternalTrait<prop, "Type">, TypeTrait;
+class PredTypeTrait<string descr, Pred pred>
+    : PredTrait<descr, pred>, TypeTrait;
+
+//===----------------------------------------------------------------------===//
+// AttrTrait definitions
+//===----------------------------------------------------------------------===//
+
+// AttrTrait marks a trait that can be attached to an attribute.
+// TODO: Remove this class in favor of using Trait.
+class AttrTrait;
+
+// These classes are used to define attribute-specific traits.
+class NativeAttrTrait<string name> : NativeTrait<name, "Attribute">, AttrTrait;
+class ParamNativeAttrTrait<string prop, string params>
+    : ParamNativeTrait<prop, params, "Attribute">, AttrTrait;
+class GenInternalAttrTrait<string prop>
+    : GenInternalTrait<prop, "Attribute">, AttrTrait;
+class PredAttrTrait<string descr, Pred pred>
+    : PredTrait<descr, pred>, AttrTrait;
+
 //===----------------------------------------------------------------------===//
 // OpTrait definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index a670d62eddc2e..28b3326e9f75d 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -57,6 +57,25 @@ class StorageUserTraitBase {
 // StorageUserBase
 //===----------------------------------------------------------------------===//
 
+namespace storage_user_base_impl {
+/// Returns true if the given trait ID matches the ID of any of the provided
+/// trait types `Traits`.
+template <template <typename T> class... Traits>
+bool hasTrait(TypeID traitID) {
+  TypeID traitIDs[] = {TypeID::get<Traits>()...};
+  for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
+    if (traitIDs[i] == traitID)
+      return true;
+  return false;
+}
+
+// Specialize the empty pack: a zero-length array is ill-formed in C++.
+template <>
+inline bool hasTrait(TypeID traitID) {
+  return false;
+}
+} // namespace storage_user_base_impl
+
 /// Utility class for implementing users of storage classes uniqued by a
 /// StorageUniquer. Clients are not expected to interact with this class
 /// directly.
@@ -69,6 +88,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// Utility declarations for the concrete attribute class.
   using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
   using ImplType = StorageT;
+  using HasTraitFn = bool (*)(TypeID);
 
   /// Return a unique identifier for the concrete type.
   static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
@@ -87,6 +107,14 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
   }
 
+  /// Returns a function that reports whether a given trait ID matches the ID
+  /// of any of the `Traits` this storage user was declared with.
+  static HasTraitFn getHasTraitFn() {
+    return [](TypeID id) {
+      return storage_user_base_impl::hasTrait<Traits...>(id);
+    };
+  }
+
   /// Attach the given models as implementations of the corresponding interfaces
   /// for the concrete storage user class. The type must be registered with the
   /// context, i.e. the dialect to which the type belongs must be loaded. The

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 40113c41fc229..6a929c6e48e0a 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -29,14 +29,18 @@ class MLIRContext;
 /// a registered Type.
 class AbstractType {
 public:
+  using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
   /// Look up the specified abstract type in the MLIRContext and return a
   /// reference to it.
   static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
 
   /// This method is used by Dialect objects when they register the list of
   /// types they contain.
-  template <typename T> static AbstractType get(Dialect &dialect) {
-    return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID());
+  template <typename T>
+  static AbstractType get(Dialect &dialect) {
+    return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+                        T::getTypeID());
   }
 
   /// This method is used by Dialect objects to register types with
@@ -44,8 +48,9 @@ class AbstractType {
   /// The use of this method is in general discouraged in favor of
   /// 'get<CustomType>(dialect)';
   static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-                          TypeID typeID) {
-    return AbstractType(dialect, std::move(interfaceMap), typeID);
+                          HasTraitFn &&hasTrait, TypeID typeID) {
+    return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
+                        typeID);
   }
 
   /// Return the dialect this type was registered to.
@@ -63,14 +68,23 @@ class AbstractType {
     return interfaceMap.contains(interfaceID);
   }
 
+  /// Returns true if the type has a particular trait.
+  template <template <typename T> class Trait>
+  bool hasTrait() const {
+    return hasTraitFn(TypeID::get<Trait>());
+  }
+
+  /// Returns true if the type has the trait identified by `traitID`.
+  bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
   /// Return the unique identifier representing the concrete type class.
   TypeID getTypeID() const { return typeID; }
 
 private:
   AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
-               TypeID typeID)
+               HasTraitFn &&hasTrait, TypeID typeID)
       : dialect(dialect), interfaceMap(std::move(interfaceMap)),
-        typeID(typeID) {}
+        hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
 
   /// Give StorageUserBase access to the mutable lookup.
   template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -88,6 +102,9 @@ class AbstractType {
   /// This is a collection of the interfaces registered to this type.
   detail::InterfaceMap interfaceMap;
 
+  /// Function to check if the type has a particular trait.
+  HasTraitFn hasTraitFn;
+
   /// The unique identifier of the derived Type class.
   const TypeID typeID;
 };

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a260aa5fc3636..ade0a77b04e09 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,12 @@ class Type {
     return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
   }
 
+  /// Returns true if the type was registered with a particular trait.
+  template <template <typename T> class Trait>
+  bool hasTrait() {
+    return getAbstractType().hasTrait<Trait>();
+  }
+
   /// Return the abstract type descriptor for this type.
   const AbstractTy &getAbstractType() { return impl->getAbstractType(); }

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index fe4dcacc0ff89..4d15d70b21c02 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -564,3 +564,39 @@ func @graph_region_cant_have_blocks() {
     "terminator"() : () -> ()
   }
 }
+
+// -----
+
+// Check that we can query traits in types
+func @succeeded_type_traits() {
+  // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  return
+}
+
+// -----
+
+// Check that a type without the required trait is rejected
+func @failed_type_traits() {
+  // expected-error@+1 {{result type should have trait 'TestTypeTrait'}}
+  "test.result_type_with_trait"() : () -> i32
+  return
+}
+
+// -----
+
+// Check that we can query traits in attributes
+func @succeeded_attr_traits() {
+  // CHECK: "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+  "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+  return
+}
+
+// -----
+
+// Check that an attribute without the required trait is rejected
+func @failed_attr_traits() {
+  // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
+  "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 33a94a822a994..2dc583a8cfd03 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -17,7 +17,8 @@
 include "TestOps.td"
 
 // All of the attributes will extend this class.
-class Test_Attr<string name> : AttrDef<Test_Dialect, name>;
+class Test_Attr<string name, list<Trait> traits = []>
+    : AttrDef<Test_Dialect, name, traits>;
 
 def SimpleAttrA : Test_Attr<"SimpleA"> {
   let mnemonic = "smpla";
@@ -54,4 +55,12 @@ def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
   let typeBuilder = "$_attr.getType()";
 }
 
+def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
+
+// The definition of a parameterless attribute that has a trait.
+def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
+  let mnemonic = "attr_with_trait";
+  let parameters = (ins );
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index f1021141729d7..37c409f853ae3 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -16,6 +16,7 @@
 
 #include <tuple>
 
+#include "TestTraits.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e21ae0ba0f031..fcc0f2861ccf2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -648,6 +648,29 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This operation requires its result type to have the trait 'TestTypeTrait'.
+def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
+  let results = (outs AnyType);
+
+  let verifier = [{
+    if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
+      return success();
+    return this->emitError("result type should have trait 'TestTypeTrait'");
+  }];
+}
+
+// This operation requires its "attr" attribute to have the
+// trait 'TestAttrTrait'.
+def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
+  let arguments = (ins AnyAttr:$attr);
+
+  let verifier = [{
+    if (this->attr().hasTrait<AttributeTrait::TestAttrTrait>())
+      return success();
+    return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Locations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestTraits.h b/mlir/test/lib/Dialect/Test/TestTraits.h
new file mode 100644
index 0000000000000..309d77a29032c
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTraits.h
@@ -0,0 +1,39 @@
+//===- TestTraits.h - MLIR Test Traits --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains traits defined by the TestDialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRAITS_H
+#define MLIR_TESTTRAITS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace TypeTrait {
+
+/// A trait defined on types for testing purposes.
+template <typename ConcreteType>
+class TestTypeTrait : public TypeTrait::TraitBase<ConcreteType, TestTypeTrait> {
+};
+
+} // namespace TypeTrait
+
+namespace AttributeTrait {
+
+/// A trait defined on attributes for testing purposes.
+template <typename ConcreteType>
+class TestAttrTrait
+    : public AttributeTrait::TraitBase<ConcreteType, TestAttrTrait> {};
+
+} // namespace AttributeTrait
+} // namespace mlir
+
+#endif // MLIR_TESTTRAITS_H

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index e11a042766bf0..aef9baa894737 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -182,4 +182,11 @@ def TestMemRefElementType : Test_Type<"TestMemRefElementType",
   let mnemonic = "memref_element";
 }
 
+def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
+
+// The definition of a singleton type that has a trait.
+def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
+  let mnemonic = "test_type_with_trait";
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 7ee722197a25f..9da2e1713d9d0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -16,6 +16,7 @@
 
 #include <tuple>
 
+#include "TestTraits.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"


        


More information about the Mlir-commits mailing list