[Mlir-commits] [mlir] [mlir] Add a ValueSemantics trait. (PR #99493)

Alexander Belyaev llvmlistbot at llvm.org
Fri Jul 19 12:20:27 PDT 2024


https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/99493

>From d79e52156c94a1775fcb16b7b32f1868040ba978 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Fri, 19 Jul 2024 21:19:58 +0200
Subject: [PATCH] [mlir] Add a ValueSemantics trait.

We need to distinguish ShapedTypes with and without value semantics. This lets downstream users define custom vector and tensor
types that work with the arith and math dialects.

RFC https://discourse.llvm.org/t/rfc-mlir-types-with-encoding/80189
---
 mlir/include/mlir/IR/BuiltinTypes.h           |  5 ++++
 mlir/include/mlir/IR/BuiltinTypes.td          | 14 ++++++++--
 mlir/include/mlir/IR/CommonTypeConstraints.td | 26 +++++++++++++++----
 3 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..f04668ed9142c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -39,6 +39,11 @@ struct IntegerTypeStorage;
 struct TupleTypeStorage;
 } // namespace detail
 
+/// Type trait for types with value semantics (e.g. vector, ranked tensor).
+template <typename ConcreteType>
+class ValueSemantics
+    : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
+
 //===----------------------------------------------------------------------===//
 // FloatType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..c8f4bea748129 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -30,6 +30,15 @@ class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
   let typeName = "builtin." # typeMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+/// Type trait for types with value semantics (e.g. vector, ranked tensor).
+def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
+  let cppNamespace = "::mlir";
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexType
 //===----------------------------------------------------------------------===//
@@ -745,7 +754,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface
+    ShapedTypeInterface, ValueSemantics
   ], "TensorType"> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
@@ -1049,7 +1058,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> {
+def Builtin_Vector : Builtin_Type<"Vector", "vector",
+    [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
   let description = [{
     Syntax:
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..9b524f02e5d26 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -89,6 +89,9 @@ def HasStaticShapePred :
 // Whether a type is a TupleType.
 def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;
 
+// Whether a type has the ValueSemantics type trait.
+def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
+
 //===----------------------------------------------------------------------===//
 // Type definitions
 //===----------------------------------------------------------------------===//
@@ -403,6 +406,11 @@ class HasRankGreaterOrEqualPred<int rank> : And<[
     CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
 ]>;
 
+// Shaped container type (e.g. vector, ranked tensor) with value semantics.
+class ValueSemanticsContainerOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, HasValueSemanticsPred,
+  "container with value semantics">;
+
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
@@ -842,10 +850,18 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Common type constraints
 //===----------------------------------------------------------------------===//
 // Type constraint for types that are "like" some type or set of types T, that is
-// they're either a T, a vector of Ts, or a tensor of Ts
+// they're either a T, a vector of Ts, or a tensor of Ts.
 class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
   allowedType.predicate, VectorOf<[allowedType]>.predicate,
   TensorOf<[allowedType]>.predicate]>,
   name>;
 
+// Type constraint for types that are "like" some type or set of types T, that is
+// they're either a T or a mappable container of Ts with value semantics.
+class TypeOrValueSemanticsContainer<Type allowedType, string name>
+    : TypeConstraint<Or<[
+  allowedType.predicate,
+  ValueSemanticsContainerOf<[allowedType]>.predicate]>,
+  name>;
+
 // Temporary constraint to allow gradual transition to supporting 0-D vectors.
@@ -864,8 +880,8 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
 
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
-def SignlessIntegerLike : TypeOrContainer<AnySignlessIntegerOrIndex,
-    "signless-integer-like">;
+def SignlessIntegerLike : TypeOrValueSemanticsContainer<
+    AnySignlessIntegerOrIndex, "signless-integer-like">;
 
 def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
     AnySignlessIntegerOrIndex,



More information about the Mlir-commits mailing list