[Mlir-commits] [mlir] [mlir] Add a MappableContainer trait. (PR #99493)

Alexander Belyaev llvmlistbot at llvm.org
Fri Jul 19 01:00:15 PDT 2024


https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/99493

>From 151af4a23a244739faf258caa75ed714c58ce878 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Fri, 19 Jul 2024 09:58:56 +0200
Subject: [PATCH] [mlir] Add a MappableContainer trait.

This is needed for downstream users to define their own custom vector and tensor
types that can work with the arith and math dialects.

RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
---
 mlir/include/mlir/IR/BuiltinTypes.h           | 10 +++++++
 mlir/include/mlir/IR/BuiltinTypes.td          | 19 ++++++++++--
 mlir/include/mlir/IR/CommonTypeConstraints.td | 30 +++++++++++++++----
 3 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..542a1352ae4dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -39,6 +39,16 @@ struct IntegerTypeStorage;
 struct TupleTypeStorage;
 } // namespace detail
 
+/// Type trait indicating that the type can be an operand to an elementwise op.
+template <typename ConcreteType>
+class MappableContainer
+    : public TypeTrait::TraitBase<ConcreteType, MappableContainer> {};
+
+/// Type trait indicating that the type has value semantics.
+template <typename ConcreteType>
+class ValueSemantics
+    : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
+
 //===----------------------------------------------------------------------===//
 // FloatType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..82ee3c148c8b8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -30,6 +30,20 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
   let typeName = "builtin." # typeMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+/// Type trait indicating that the type can be an operand to an elementwise op.
+def MappableContainer : NativeTypeTrait<"MappableContainer"> {
+  let cppNamespace = "::mlir";
+}
+
+/// Type trait indicating that the type has value semantics.
+def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
+  let cppNamespace = "::mlir";
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexType
 //===----------------------------------------------------------------------===//
@@ -745,7 +759,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface
+    MappableContainer, ShapedTypeInterface, ValueSemantics
   ], "TensorType"> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
@@ -1049,7 +1063,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
+def Builtin_Vector : Builtin_Type<"Vector", "vector",
+    [MappableContainer, ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
   let description = [{
     Syntax:
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..61c1f47b1d6dc 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -89,6 +89,12 @@ def HasStaticShapePred :
 // Whether a type is a TupleType.
 def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
 
+// Whether a type has the ::mlir::MappableContainer trait (name fully qualified).
+def IsMappableContainerPred : CPred<"$_self.hasTrait<::mlir::MappableContainer>()">;
+
+// Whether a type has the ::mlir::ValueSemantics trait (name fully qualified).
+def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
+
 //===----------------------------------------------------------------------===//
 // Type definitions
 //===----------------------------------------------------------------------===//
@@ -403,6 +409,12 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
     CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
 ]>;
 
+// Mappable types with value semantics.
+class ValueSemanticsMappableContainerOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes,
+  And<[HasValueSemanticsPred, IsMappableContainerPred]>,
+  "mappable container with value semantics">;
+
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +854,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Common type constraints
 //===----------------------------------------------------------------------===//
 // Type constraint for types that are "like" some type or set of types T, that is
-// they're either a T, a vector of Ts, or a tensor of Ts
+// they're either a T, or a vector, tensor, or other mappable container of Ts.
 class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
-  allowedType.predicate, VectorOf<[allowedType]>.predicate,
-  TensorOf<[allowedType]>.predicate]>,
+  allowedType.predicate,
+  ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
+  name>;
+
+// Type constraint for types that are "like" some type or set of types T, that is
+// they're either a T or a mappable container of Ts with value semantics.
+class TypeOrValueSemanticsMappableContainer<Type allowedType, string name>
+    : TypeConstraint<Or<[
+  allowedType.predicate,
+  ValueSemanticsMappableContainerOf<[allowedType]>.predicate]>,
   name>;
 
 // Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +884,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
 
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
-def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
-    "signless-integer-like">;
+def SignlessIntegerLike : TypeOrValueSemanticsMappableContainer<
+    AnySignlessIntegerOrIndex, "signless-integer-like">;
 
 def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
     AnySignlessIntegerOrIndex,



More information about the Mlir-commits mailing list