[Mlir-commits] [mlir] 802bf02 - [mlir] Allows to query traits from types and attributes
Mehdi Amini
llvmlistbot at llvm.org
Sun Sep 12 23:29:46 PDT 2021
Author: Mathieu Fehr
Date: 2021-09-13T06:26:45Z
New Revision: 802bf02a738e091d5bf22c03e83204a38d2c7950
URL: https://github.com/llvm/llvm-project/commit/802bf02a738e091d5bf22c03e83204a38d2c7950
DIFF: https://github.com/llvm/llvm-project/commit/802bf02a738e091d5bf22c03e83204a38d2c7950.diff
LOG: [mlir] Allows to query traits from types and attributes
Types and attributes now have a `hasTrait` function that allows users to check
if a type or attribute defines a trait.
Also, AbstractType and AbstractAttribute now have a `hasTraitFn` field to carry
the implementation of the `hasTrait` function of the concrete type or attribute.
This patch also adds the remaining functions to access type and attribute traits
in TableGen.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D105202
Added:
mlir/test/lib/Dialect/Test/TestTraits.h
Modified:
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/test/IR/traits.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.h
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index c84be620b8f0f..d18f0ab8aad70 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -30,14 +30,18 @@ class Type;
/// a registered Attribute.
class AbstractAttribute {
public:
+ using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
/// Look up the specified abstract attribute in the MLIRContext and return a
/// reference to it.
static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// attributes they contain.
- template <typename T> static AbstractAttribute get(Dialect &dialect) {
- return AbstractAttribute(dialect, T::getInterfaceMap(), T::getTypeID());
+ template <typename T>
+ static AbstractAttribute get(Dialect &dialect) {
+ return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+ T::getTypeID());
}
/// Return the dialect this attribute was registered to.
@@ -46,7 +50,8 @@ class AbstractAttribute {
/// Returns an instance of the concept object for the given interface if it
/// was registered to this attribute, null otherwise. This should not be used
/// directly.
- template <typename T> typename T::Concept *getInterface() const {
+ template <typename T>
+ typename T::Concept *getInterface() const {
return interfaceMap.lookup<T>();
}
@@ -56,14 +61,23 @@ class AbstractAttribute {
return interfaceMap.contains(interfaceID);
}
+ /// Returns true if the attribute has a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
+ return hasTraitFn(TypeID::get<Trait>());
+ }
+
+ /// Returns true if the attribute has a particular trait.
+ bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID)
+ HasTraitFn &&hasTrait, TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- typeID(typeID) {}
+ hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -81,6 +95,9 @@ class AbstractAttribute {
/// This is a collection of the interfaces registered to this attribute.
detail::InterfaceMap interfaceMap;
+ /// Function to check if the attribute has a particular trait.
+ HasTraitFn hasTraitFn;
+
/// The unique identifier of the derived Attribute class.
const TypeID typeID;
};
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 2bc294b3609f6..69e48d7a72561 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -84,6 +84,12 @@ class Attribute {
friend ::llvm::hash_code hash_value(Attribute arg);
+ /// Returns true if the attribute was registered with a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() {
+ return getAbstractAttribute().hasTrait<Trait>();
+ }
+
/// Return the abstract descriptor for this attribute.
const AbstractTy &getAbstractAttribute() const {
return impl->getAbstractAttribute();
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 4c8b5e5014e0a..cb97424281012 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1809,6 +1809,40 @@ class PredTrait<string descr, Pred pred> : Trait {
Pred predicate = pred;
}
+//===----------------------------------------------------------------------===//
+// TypeTrait definitions
+//===----------------------------------------------------------------------===//
+
+// TypeTrait represents a trait regarding a type.
+// TODO: Remove this class in favor of using Trait.
+class TypeTrait;
+
+// These classes are used to define type specific traits.
+class NativeTypeTrait<string name> : NativeTrait<name, "Type">, TypeTrait;
+class ParamNativeTypeTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Type">, TypeTrait;
+class GenInternalTypeTrait<string prop>
+ : GenInternalTrait<prop, "Type">, TypeTrait;
+class PredTypeTrait<string descr, Pred pred>
+ : PredTrait<descr, pred>, TypeTrait;
+
+//===----------------------------------------------------------------------===//
+// AttrTrait definitions
+//===----------------------------------------------------------------------===//
+
+// AttrTrait represents a trait regarding an attribute.
+// TODO: Remove this class in favor of using Trait.
+class AttrTrait;
+
+// These classes are used to define attribute specific traits.
+class NativeAttrTrait<string name> : NativeTrait<name, "Attribute">, AttrTrait;
+class ParamNativeAttrTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Attribute">, AttrTrait;
+class GenInternalAttrTrait<string prop>
+ : GenInternalTrait<prop, "Attribute">, AttrTrait;
+class PredAttrTrait<string descr, Pred pred>
+ : PredTrait<descr, pred>, AttrTrait;
+
//===----------------------------------------------------------------------===//
// OpTrait definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index a670d62eddc2e..28b3326e9f75d 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -57,6 +57,25 @@ class StorageUserTraitBase {
// StorageUserBase
//===----------------------------------------------------------------------===//
+namespace storage_user_base_impl {
+/// Returns true if the given trait ID matches the ID of any of the provided
+/// trait types `Traits`.
+template <template <typename T> class... Traits>
+bool hasTrait(TypeID traitID) {
+ TypeID traitIDs[] = {TypeID::get<Traits>()...};
+ for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
+ if (traitIDs[i] == traitID)
+ return true;
+ return false;
+}
+
+// We specialize for the empty case to not define an empty array.
+template <>
+inline bool hasTrait(TypeID traitID) {
+ return false;
+}
+} // namespace storage_user_base_impl
+
/// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class
/// directly.
@@ -69,6 +88,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Utility declarations for the concrete attribute class.
using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
using ImplType = StorageT;
+ using HasTraitFn = bool (*)(TypeID);
/// Return a unique identifier for the concrete type.
static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
@@ -87,6 +107,14 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
}
+ /// Returns the function that returns true if the given Trait ID matches the
+ /// IDs of any of the traits defined by the storage user.
+ static HasTraitFn getHasTraitFn() {
+ return [](TypeID id) {
+ return storage_user_base_impl::hasTrait<Traits...>(id);
+ };
+ }
+
/// Attach the given models as implementations of the corresponding interfaces
/// for the concrete storage user class. The type must be registered with the
/// context, i.e. the dialect to which the type belongs must be loaded. The
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 40113c41fc229..6a929c6e48e0a 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -29,14 +29,18 @@ class MLIRContext;
/// a registered Type.
class AbstractType {
public:
+ using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
/// Look up the specified abstract type in the MLIRContext and return a
/// reference to it.
static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// types they contain.
- template <typename T> static AbstractType get(Dialect &dialect) {
- return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID());
+ template <typename T>
+ static AbstractType get(Dialect &dialect) {
+ return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+ T::getTypeID());
}
/// This method is used by Dialect objects to register types with
@@ -44,8 +48,9 @@ class AbstractType {
/// The use of this method is in general discouraged in favor of
/// 'get<CustomType>(dialect)';
static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID) {
- return AbstractType(dialect, std::move(interfaceMap), typeID);
+ HasTraitFn &&hasTrait, TypeID typeID) {
+ return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
+ typeID);
}
/// Return the dialect this type was registered to.
@@ -63,14 +68,23 @@ class AbstractType {
return interfaceMap.contains(interfaceID);
}
+ /// Returns true if the type has a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
+ return hasTraitFn(TypeID::get<Trait>());
+ }
+
+ /// Returns true if the type has a particular trait.
+ bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
/// Return the unique identifier representing the concrete type class.
TypeID getTypeID() const { return typeID; }
private:
AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID)
+ HasTraitFn &&hasTrait, TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- typeID(typeID) {}
+ hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
@@ -88,6 +102,9 @@ class AbstractType {
/// This is a collection of the interfaces registered to this type.
detail::InterfaceMap interfaceMap;
+ /// Function to check if the type has a particular trait.
+ HasTraitFn hasTraitFn;
+
/// The unique identifier of the derived Type class.
const TypeID typeID;
};
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a260aa5fc3636..ade0a77b04e09 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,12 @@ class Type {
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
+ /// Returns true if the type was registered with a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() {
+ return getAbstractType().hasTrait<Trait>();
+ }
+
/// Return the abstract type descriptor for this type.
const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index fe4dcacc0ff89..4d15d70b21c02 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -564,3 +564,39 @@ func @graph_region_cant_have_blocks() {
"terminator"() : () -> ()
}
}
+
+// -----
+
+// Check that we can query traits in types
+func @succeeded_type_traits() {
+ // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+ "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+ return
+}
+
+// -----
+
+// Check that we can query traits in types
+func @failed_type_traits() {
+ // expected-error@+1 {{result type should have trait 'TestTypeTrait'}}
+ "test.result_type_with_trait"() : () -> i32
+ return
+}
+
+// -----
+
+// Check that we can query traits in attributes
+func @succeeded_attr_traits() {
+ // CHECK: "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+ "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+ return
+}
+
+// -----
+
+// Check that we can query traits in attributes
+func @failed_attr_traits() {
+ // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
+ "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 33a94a822a994..2dc583a8cfd03 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -17,7 +17,8 @@
include "TestOps.td"
// All of the attributes will extend this class.
-class Test_Attr<string name> : AttrDef<Test_Dialect, name>;
+class Test_Attr<string name, list<Trait> traits = []>
+ : AttrDef<Test_Dialect, name, traits>;
def SimpleAttrA : Test_Attr<"SimpleA"> {
let mnemonic = "smpla";
@@ -54,4 +55,12 @@ def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
let typeBuilder = "$_attr.getType()";
}
+def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
+
+// The definition of a singleton attribute that has a trait.
+def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
+ let mnemonic = "attr_with_trait";
+ let parameters = (ins );
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index f1021141729d7..37c409f853ae3 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -16,6 +16,7 @@
#include <tuple>
+#include "TestTraits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e21ae0ba0f031..fcc0f2861ccf2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -648,6 +648,30 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This operation requires its return type to have the trait 'TestTypeTrait'.
+def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
+ let results = (outs AnyType);
+
+ let verifier = [{
+ if((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
+ return success();
+ return this->emitError("result type should have trait 'TestTypeTrait'");
+ }];
+}
+
+// This operation requires its "attr" attribute to have the
+// trait 'TestAttrTrait'.
+def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
+ let arguments = (ins AnyAttr:$attr);
+
+ let verifier = [{
+ if (this->attr().hasTrait<AttributeTrait::TestAttrTrait>())
+ return success();
+ return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// Test Locations
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.h b/mlir/test/lib/Dialect/Test/TestTraits.h
new file mode 100644
index 0000000000000..309d77a29032c
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTraits.h
@@ -0,0 +1,39 @@
+//===- TestTraits.h - MLIR Test Traits --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains traits defined by the TestDialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRAITS_H
+#define MLIR_TESTTRAITS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace TypeTrait {
+
+/// A trait defined on types for testing purposes.
+template <typename ConcreteType>
+class TestTypeTrait : public TypeTrait::TraitBase<ConcreteType, TestTypeTrait> {
+};
+
+} // namespace TypeTrait
+
+namespace AttributeTrait {
+
+/// A trait defined on attributes for testing purposes.
+template <typename ConcreteType>
+class TestAttrTrait
+ : public AttributeTrait::TraitBase<ConcreteType, TestAttrTrait> {};
+
+} // namespace AttributeTrait
+} // namespace mlir
+
+#endif // MLIR_TESTTRAITS_H
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index e11a042766bf0..aef9baa894737 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -182,4 +182,11 @@ def TestMemRefElementType : Test_Type<"TestMemRefElementType",
let mnemonic = "memref_element";
}
+def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
+
+// The definition of a singleton type that has a trait.
+def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
+ let mnemonic = "test_type_with_trait";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 7ee722197a25f..9da2e1713d9d0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -16,6 +16,7 @@
#include <tuple>
+#include "TestTraits.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
More information about the Mlir-commits
mailing list