[Mlir-commits] [mlir] a518299 - [mlir] Support for mutable types

Alex Zinenko llvmlistbot at llvm.org
Mon Jul 27 04:07:53 PDT 2020


Author: Alex Zinenko
Date: 2020-07-27T13:07:44+02:00
New Revision: a51829913dba28dae603fdcdddd242c7e20192a1

URL: https://github.com/llvm/llvm-project/commit/a51829913dba28dae603fdcdddd242c7e20192a1
DIFF: https://github.com/llvm/llvm-project/commit/a51829913dba28dae603fdcdddd242c7e20192a1.diff

LOG: [mlir] Support for mutable types

Introduce support for mutable storage in the StorageUniquer infrastructure.
This makes MLIR have key-value storage instead of just uniqued key storage. A
storage instance now contains a unique immutable key and a mutable value, both
stored in the arena allocator that belongs to the context. This is a
precondition for supporting recursive types that require delayed initialization,
in particular LLVM structure types. The functionality is exercised in the test
pass with a trivial self-recursive type. So far, recursive types can only be
printed and parsed in a closed type system. Removing this restriction is left
for future work.
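
The key-value model described above can be illustrated with a minimal standalone C++ sketch. It is not MLIR code: `KeyValueStorage` and `KeyValueUniquer` are hypothetical names, and a `std::unordered_map` of owned pointers stands in for the context's arena allocator. Requests for the same key return the same instance, and changing the value does not change which instance a key maps to.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a uniqued storage instance: the key is fixed at
// construction, while the value can change later without affecting identity.
struct KeyValueStorage {
  explicit KeyValueStorage(std::string key) : key(std::move(key)) {}

  const std::string key; // immutable: used for uniquing
  int value = 0;         // mutable: not part of the key
};

// Hypothetical uniquer: at most one storage instance per key. The map owns the
// instances (playing the role of the arena), so returned pointers stay valid
// for the lifetime of the uniquer.
class KeyValueUniquer {
public:
  // Returns the unique instance for `key`, creating it on first request.
  KeyValueStorage *get(const std::string &key) {
    auto it = instances.find(key);
    if (it != instances.end())
      return it->second.get();
    auto storage = std::make_unique<KeyValueStorage>(key);
    KeyValueStorage *result = storage.get();
    instances.emplace(key, std::move(storage));
    return result;
  }

private:
  std::unordered_map<std::string, std::unique_ptr<KeyValueStorage>> instances;
};
```

Because identity is decided by the key alone, a pointer obtained before a mutation still compares equal to one obtained after it.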

Differential Revision: https://reviews.llvm.org/D84171

Added: 
    mlir/test/IR/recursive-type.mlir
    mlir/test/lib/IR/TestTypes.cpp

Modified: 
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/lib/Support/StorageUniquer.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index cab2441b3320..45756e1a31ea 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -47,7 +47,8 @@ namespace MyTypes {
 enum Kinds {
   // These kinds will be used in the examples below.
   Simple = Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
-  Complex
+  Complex,
+  Recursive
 };
 }
 ```
@@ -58,13 +59,17 @@ As described above, `Type` objects in MLIR are value-typed and rely on having an
 implicitly internal storage object that holds the actual data for the type. When
 defining a new `Type` it isn't always necessary to define a new storage class.
 So before defining the derived `Type`, it's important to know which of the two
-classes of `Type` we are defining. Some types are `primitives` meaning they do
+classes of `Type` we are defining. Some types are _primitives_ meaning they do
 not have any parameters and are singletons uniqued by kind, like the
 [`index` type](LangRef.md#index-type). Parametric types on the other hand, have
 additional information that differentiates different instances of the same
 `Type` kind. For example the [`integer` type](LangRef.md#integer-type) has a
 bitwidth, making `i8` and `i16` be different instances of
-[`integer` type](LangRef.md#integer-type).
+[`integer` type](LangRef.md#integer-type). Types can also have a mutable
+component, which can be used, for example, to construct self-referring recursive
+types. The mutable component _cannot_ be used to differentiate types within the
+same kind, so usually such types are also parametric where the parameters serve
+to identify them.
 
 #### Simple non-parametric types
 
@@ -240,6 +245,126 @@ public:
 };
 ```
 
+#### Types with a mutable component
+
+Types with a mutable component require defining a type storage class regardless
+of whether they are parametric. The storage contains both the parameters and the
+mutable component and is accessed in a thread-safe way by the type support
+infrastructure.
+
+##### Defining a type storage
+
+In addition to the requirements for the type storage class for parametric types,
+the storage class for types with a mutable component must obey the following.
+
+*   The mutable component must not participate in the storage key.
+*   Provide a mutation method that is used to modify an existing instance of the
+    storage. This method modifies the mutable component based on arguments,
+    using `allocator` for any new dynamically-allocated storage, and indicates
+    whether the modification was successful.
+    -   `LogicalResult mutate(StorageAllocator &allocator, Args &&...args)`
+
+Let's define a simple storage for recursive types, where a type is identified by
+its name and can contain another type including itself.
+
+```c++
+/// Here we define a storage class for a RecursiveType that is identified by its
+/// name and contains another type.
+struct RecursiveTypeStorage : public TypeStorage {
+  /// The type is uniquely identified by its name. Note that the contained type
+  /// is _not_ a part of the key.
+  using KeyTy = StringRef;
+
+  /// Construct the storage from the type name. Explicitly initialize the
+  /// containedType to nullptr, which marks the mutable component as not yet
+  /// initialized.
+  RecursiveTypeStorage(StringRef name) : name(name), containedType(nullptr) {}
+
+  /// Define the comparison function.
+  bool operator==(const KeyTy &key) const { return key == name; }
+
+  /// Define a construction method for creating a new instance of the storage.
+  static RecursiveTypeStorage *construct(StorageAllocator &allocator,
+                                         const KeyTy &key) {
+    // Note that the key string is copied into the allocator to ensure it
+    // remains live as long as the storage itself.
+    return new (allocator.allocate<RecursiveTypeStorage>())
+        RecursiveTypeStorage(allocator.copyInto(key));
+  }
+
+  /// Define a mutation method for changing the type after it is created. In
+  /// many cases, we only want to set the mutable component once and reject
+  /// any further modification, which can be achieved by returning failure from
+  /// this function.
+  LogicalResult mutate(StorageAllocator &, Type body) {
+    // If the contained type has been initialized already, and the call tries
+    // to change it, reject the change.
+    if (containedType && containedType != body)
+      return failure();
+
+    // Change the body successfully.
+    containedType = body;
+    return success();
+  }
+
+  StringRef name;
+  Type containedType;
+};
+```
+
+##### Type class definition
+
+Having defined the storage class, we can define the type class itself. This is
+similar to parametric types. `Type::TypeBase` provides a `mutate` method that
+forwards its arguments to the `mutate` method of the storage and ensures the
+modification happens under lock.
+
+```c++
+class RecursiveType : public Type::TypeBase<RecursiveType, Type,
+                                            RecursiveTypeStorage> {
+public:
+  /// Inherit parent constructors.
+  using Base::Base;
+
+  /// This static method is used to support type inquiry through isa, cast,
+  /// and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; }
+
+  /// Creates an instance of the Recursive type. This only takes the type name
+  /// and returns the type with uninitialized body.
+  static RecursiveType get(MLIRContext *ctx, StringRef name) {
+    // Call into the base to get a uniqued instance of this type. The parameter
+    // (name) is passed after the kind.
+    return Base::get(ctx, MyTypes::Recursive, name);
+  }
+
+  /// Now we can change the mutable component of the type. This is an instance
+  /// method callable on an already existing RecursiveType.
+  void setBody(Type body) {
+    // Call into the base to mutate the type.
+    LogicalResult result = Base::mutate(body);
+    // Most types expect mutation to always succeed, but types can implement
+    // custom logic for handling mutation failures.
+    assert(succeeded(result) &&
+           "attempting to change the body of an already-initialized type");
+    // Avoid unused-variable warning when building without assertions.
+    (void) result;
+  }
+
+  /// Returns the contained type, which may be null if it has not been
+  /// initialized yet.
+  Type getBody() {
+    return getImpl()->containedType;
+  }
+
+  /// Returns the name.
+  StringRef getName() {
+    return getImpl()->name;
+  }
+};
+```
+
 ### Registering types with a Dialect
 
 Once the dialect types have been defined, they must then be registered with a

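The set-once policy of `RecursiveTypeStorage::mutate` in the tutorial can be tried outside MLIR with the following standalone sketch. It is an analogue under stated assumptions: `RecursiveStorage` is a hypothetical name, a raw pointer replaces `Type`, and `bool` replaces `LogicalResult`.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical analogue of RecursiveTypeStorage: identified by its name, with a
// contained "type" that may be set after creation, including to itself.
struct RecursiveStorage {
  explicit RecursiveStorage(std::string name) : name(std::move(name)) {}

  // Accept the first assignment and re-assignment of the same body; reject any
  // attempt to change an already-initialized body.
  bool mutate(const RecursiveStorage *body) {
    if (containedType && containedType != body)
      return false;
    containedType = body;
    return true;
  }

  const std::string name;                          // immutable key
  const RecursiveStorage *containedType = nullptr; // mutable; null means unset
};
```

Reporting failure instead of asserting keeps the policy in the storage, while the caller (in the tutorial, `RecursiveType::setBody`) decides how to react.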
diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index cd722bb0f2c5..72a89be43867 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -139,6 +139,13 @@ class AttributeUniquer {
         kind, std::forward<Args>(args)...);
   }
 
+  template <typename ImplType, typename... Args>
+  static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
+                              Args &&...args) {
+    assert(impl && "cannot mutate null attribute");
+    return ctx->getAttributeUniquer().mutate(impl, std::forward<Args>(args)...);
+  }
+
 private:
   /// Initialize the given attribute storage instance.
   static void initializeAttributeStorage(AttributeStorage *storage,

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 89dad2ec40cf..5ecf5763ecd4 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -48,10 +48,10 @@ struct SparseElementsAttributeStorage;
 
 /// Attributes are known-constant values of operations and functions.
 ///
-/// Instances of the Attribute class are references to immutable, uniqued,
-/// and immortal values owned by MLIRContext. As such, an Attribute is a thin
-/// wrapper around an underlying storage pointer. Attributes are usually passed
-/// by value.
+/// Instances of the Attribute class are references to immortal key-value pairs
+/// with an immutable, uniqued key owned by MLIRContext. As such, an Attribute
+/// is a thin wrapper around an underlying storage pointer. Attributes are
+/// usually passed by value.
 class Attribute {
 public:
   /// Integer identifier for all the concrete attribute kinds.

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c2250e854716..4c7693c28d2f 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -105,6 +105,14 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
     return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
   }
 
+  /// Mutate the current storage instance. This will not change the unique key.
+  /// The arguments are forwarded to 'ConcreteT::mutate'.
+  template <typename... Args>
+  LogicalResult mutate(Args &&...args) {
+    return UniquerT::mutate(this->getContext(), getImpl(),
+                            std::forward<Args>(args)...);
+  }
+
   /// Default implementation that just returns success.
   template <typename... Args>
   static LogicalResult verifyConstructionInvariants(Args... args) {

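`StorageUserBase::mutate` only forwards its arguments, through the uniquer, to the storage's own `mutate`. The standalone sketch below mimics that chain with hypothetical `UserBase`, `SketchUniquer` and `NamedStorage` classes; it deliberately omits the context and allocator arguments that the real entry points pass along.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical storage with a mutate() hook.
struct NamedStorage {
  std::string label;
  bool mutate(std::string newLabel) {
    label = std::move(newLabel);
    return true;
  }
};

// Hypothetical uniquer entry point, analogous to TypeUniquer::mutate: checks
// the instance and forwards the arguments unchanged to the storage.
struct SketchUniquer {
  template <typename ImplType, typename... Args>
  static bool mutate(ImplType *impl, Args &&...args) {
    assert(impl && "cannot mutate null storage");
    return impl->mutate(std::forward<Args>(args)...);
  }
};

// Hypothetical CRTP-style base, analogous to StorageUserBase: mutate() is a
// thin forwarding wrapper around the uniquer.
template <typename ConcreteT, typename StorageT, typename UniquerT>
class UserBase {
public:
  explicit UserBase(StorageT *impl) : impl(impl) {}

  template <typename... Args>
  bool mutate(Args &&...args) {
    return UniquerT::mutate(impl, std::forward<Args>(args)...);
  }

protected:
  StorageT *getImpl() const { return impl; }

private:
  StorageT *impl;
};

// A concrete handle type exposing a typed setter on top of mutate().
class Named : public UserBase<Named, NamedStorage, SketchUniquer> {
public:
  using UserBase::UserBase;
  bool setLabel(std::string label) { return mutate(std::move(label)); }
  const std::string &getLabel() const { return getImpl()->label; }
};
```

Since handles only wrap a pointer, a copy of the handle observes the mutation as well.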
diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 7961dd22d47d..ddb91e09dc89 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -132,6 +132,15 @@ struct TypeUniquer {
         },
         kind, std::forward<Args>(args)...);
   }
+
+  /// Change the mutable component of the given type instance in the provided
+  /// context.
+  template <typename ImplType, typename... Args>
+  static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
+                              Args &&...args) {
+    assert(impl && "cannot mutate null type");
+    return ctx->getTypeUniquer().mutate(impl, std::forward<Args>(args)...);
+  }
 };
 } // namespace detail
 

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index c14f8558d850..83636585c499 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -27,15 +27,17 @@ struct FunctionTypeStorage;
 struct OpaqueTypeStorage;
 } // namespace detail
 
-/// Instances of the Type class are immutable and uniqued.  They wrap a pointer
-/// to the storage object owned by MLIRContext.  Therefore, instances of Type
-/// are passed around by value.
+/// Instances of the Type class are uniqued, have an immutable identifier and an
+/// optional mutable component.  They wrap a pointer to the storage object owned
+/// by MLIRContext.  Therefore, instances of Type are passed around by value.
 ///
 /// Some types are "primitives" meaning they do not have any parameters, for
 /// example the Index type.  Parametric types have additional information that
 /// differentiates the types of the same kind between them, for example the
 /// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
-/// different instances of the IntegerType.
+/// different instances of the IntegerType.  Type parameters are part of the
+/// unique immutable key.  The mutable component of the type can be modified
+/// after the type is created, but cannot affect the identity of the type.
 ///
 /// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
 ///
@@ -62,6 +64,7 @@ struct OpaqueTypeStorage;
 ///    - The type kind (for LLVM-style RTTI).
 ///    - The dialect that defined the type.
 ///    - Any parameters of the type.
+///    - An optional mutable component.
 /// For non-parametric types, a convenience DefaultTypeStorage is provided.
 /// Parametric storage types must derive TypeStorage and respect the following:
 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
@@ -75,11 +78,14 @@ struct OpaqueTypeStorage;
 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
 ///      compare the storage instance against an instance of the key type.
 ///
-///    - Provide a construction method:
+///    - Provide a static construction method:
 ///        'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
 ///      that builds a unique instance of the derived storage. The arguments to
 ///      this function are an allocator to store any uniqued data within the
 ///      context and the key type for this storage.
+///
+///    - If they have a mutable component, this component must not be a part of
+///      the key.
 class Type {
 public:
   /// Integer identifier for all the concrete type kinds.

diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index f13a2fef9d50..3100b4454197 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -10,6 +10,7 @@
 #define MLIR_SUPPORT_STORAGEUNIQUER_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Allocator.h"
 
@@ -60,6 +61,20 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
 ///      that is called when erasing a storage instance. This should cleanup any
 ///      fields of the storage as necessary and not attempt to free the memory
 ///      of the storage itself.
+///
+/// Storage classes may have an optional mutable component, which must not take
+/// part in the unique immutable key. In this case, storage classes may be
+/// mutated with `mutate` and must additionally respect the following:
+///    - Provide a mutation method:
+///        'LogicalResult mutate(StorageAllocator &, <...>)'
+///      that is called when mutating a storage instance. The first argument is
+///      an allocator to store any mutable data, and the remaining arguments are
+///      forwarded from the call site. The storage can be mutated at any time
+///      after creation. Care must be taken to avoid excessive mutation since
+///      the allocator does not reclaim storage used by previous states. The
+///      return value of the function is used to indicate whether the mutation
+///      was successful, e.g., to limit the number of mutations or enable
+///      deferred one-time assignment of the mutable component.
 class StorageUniquer {
 public:
   StorageUniquer();
@@ -166,6 +181,17 @@ class StorageUniquer {
     return static_cast<Storage *>(getImpl(kind, ctorFn));
   }
 
+  /// Changes the mutable component of 'storage' by forwarding the trailing
+  /// arguments to the 'mutate' function of the derived class.
+  template <typename Storage, typename... Args>
+  LogicalResult mutate(Storage *storage, Args &&...args) {
+    auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
+      return static_cast<Storage &>(*storage).mutate(
+          allocator, std::forward<Args>(args)...);
+    };
+    return mutateImpl(mutationFn);
+  }
+
   /// Erases a uniqued instance of 'Storage'. This function is used for derived
   /// types that have complex storage or uniquing constraints.
   template <typename Storage, typename Arg, typename... Args>
@@ -206,6 +232,10 @@ class StorageUniquer {
                  function_ref<bool(const BaseStorage *)> isEqual,
                  function_ref<void(BaseStorage *)> cleanupFn);
 
+  /// Implementation for mutating an instance of a derived storage.
+  LogicalResult
+  mutateImpl(function_ref<LogicalResult(StorageAllocator &)> mutationFn);
+
   /// The internal implementation class.
   std::unique_ptr<detail::StorageUniquerImpl> impl;
 

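`StorageUniquer::mutate` is a template, but the locking lives in the non-template `mutateImpl`; the template side erases the storage type by wrapping the call in a lambda. A standalone sketch of that split, assuming `std::function` in place of `llvm::function_ref` and a hypothetical `SketchAllocator` in place of `StorageAllocator`:

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical allocator stand-in: only counts how often it was used.
struct SketchAllocator {
  int allocations = 0;
};

// Hypothetical storage whose mutate() uses the allocator for new data.
struct CounterStorage {
  std::vector<int> history;
  bool mutate(SketchAllocator &allocator, int next) {
    ++allocator.allocations;
    history.push_back(next);
    return true;
  }
};

class SketchStorageUniquer {
public:
  // Template front end: captures the storage and arguments in a lambda so the
  // back end never needs to know the concrete storage type.
  template <typename Storage, typename... Args>
  bool mutate(Storage *storage, Args &&...args) {
    auto mutationFn = [&](SketchAllocator &alloc) -> bool {
      return storage->mutate(alloc, std::forward<Args>(args)...);
    };
    return mutateImpl(mutationFn);
  }

  const SketchAllocator &getAllocator() const { return allocator; }

private:
  // Non-template back end; the callable is invoked immediately, never stored.
  bool mutateImpl(const std::function<bool(SketchAllocator &)> &mutationFn) {
    return mutationFn(allocator);
  }

  SketchAllocator allocator;
};
```

Because the callable runs before `mutate` returns, capturing the arguments by reference is safe, which is also why a non-owning callable reference suffices in the original.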
diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index 40304a544c4f..f7c953e98140 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -124,6 +124,16 @@ struct StorageUniquerImpl {
     storageTypes.erase(existing);
   }
 
+  /// Mutates an instance of a derived storage in a thread-safe way.
+  LogicalResult
+  mutate(function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+    if (!threadingIsEnabled)
+      return mutationFn(allocator);
+
+    llvm::sys::SmartScopedWriter<true> lock(mutex);
+    return mutationFn(allocator);
+  }
+
   //===--------------------------------------------------------------------===//
   // Instance Storage
   //===--------------------------------------------------------------------===//
@@ -214,3 +224,9 @@ void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue,
                                function_ref<void(BaseStorage *)> cleanupFn) {
   impl->erase(kind, hashValue, isEqual, cleanupFn);
 }
+
+/// Implementation for mutating an instance of a derived storage.
+LogicalResult StorageUniquer::mutateImpl(
+    function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+  return impl->mutate(mutationFn);
+}
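
The locking policy in `StorageUniquerImpl::mutate` can be sketched with the standard library, assuming `std::shared_mutex` as a stand-in for the reader-writer mutex used there. `SketchImpl` and its `read()` method are hypothetical; `read()` only illustrates that lookups can share the lock while mutation takes it exclusively.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <shared_mutex>

// Hypothetical analogue of the locking policy: without threading, mutate runs
// unlocked; with threading, it holds the exclusive (writer) side of the lock.
class SketchImpl {
public:
  explicit SketchImpl(bool threadingIsEnabled)
      : threadingIsEnabled(threadingIsEnabled) {}

  bool mutate(const std::function<bool(int &)> &mutationFn) {
    if (!threadingIsEnabled)
      return mutationFn(state);

    std::unique_lock<std::shared_mutex> lock(mutex); // writer lock
    return mutationFn(state);
  }

  // A lookup only needs the shared (reader) side of the lock.
  int read() {
    if (!threadingIsEnabled)
      return state;
    std::shared_lock<std::shared_mutex> lock(mutex); // reader lock
    return state;
  }

private:
  bool threadingIsEnabled;
  std::shared_mutex mutex;
  int state = 0; // stands in for the allocator handed to mutationFn
};
```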

diff  --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
new file mode 100644
index 000000000000..6f90c3c78051
--- /dev/null
+++ b/mlir/test/IR/recursive-type.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
+
+// CHECK-LABEL: @roundtrip
+func @roundtrip() {
+  // CHECK: !test.test_rec<a, test_rec<b, test_type>>
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
+  // CHECK: !test.test_rec<c, test_rec<c>>
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
+  return
+}
+
+// CHECK-LABEL: @create
+func @create() {
+  // CHECK: !test.test_rec<some_long_and_unique_name, test_rec<some_long_and_unique_name>>
+  return
+}
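
The textual form exercised above can be read by a recursive-descent parser that keeps a stack of the recursive types currently being parsed, as `parseTestType` does in the test dialect. The standalone sketch below is a simplified analogue: the `Node` struct is hypothetical, a `std::vector` replaces `llvm::SetVector`, a minimal tokenizer replaces `DialectAsmParser`, and the `!test.` dialect prefix is not handled.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the test types: the leaf `test_type` or a named
// recursive type with a (possibly unset) body.
struct Node {
  bool isLeaf;
  std::string name;
  const Node *body = nullptr;
};

class SketchParser {
public:
  explicit SketchParser(std::string text) : text(std::move(text)) {}

  // Parses one type; returns nullptr on malformed input or a conflicting body.
  const Node *parseType() {
    std::string tag = parseKeyword();
    if (tag == "test_type")
      return &leaf;
    if (tag != "test_rec" || !consume('<'))
      return nullptr;
    std::string name = parseKeyword();
    if (name.empty())
      return nullptr;
    Node *rec = getOrCreate(name);
    // Already being parsed further up the stack: expect just the name.
    if (std::find(stack.begin(), stack.end(), rec) != stack.end())
      return consume('>') ? rec : nullptr;
    // Otherwise parse the body, then set it once.
    if (!consume(','))
      return nullptr;
    stack.push_back(rec);
    const Node *body = parseType();
    stack.pop_back();
    if (!body || !consume('>') || !setBody(rec, body))
      return nullptr;
    return rec;
  }

  // True once all input has been consumed.
  bool atEnd() {
    skipSpaces();
    return pos == text.size();
  }

private:
  // Uniquing by name: the same name always yields the same node.
  Node *getOrCreate(const std::string &name) {
    std::unique_ptr<Node> &slot = recs[name];
    if (!slot)
      slot.reset(new Node{false, name, nullptr});
    return slot.get();
  }

  static bool setBody(Node *rec, const Node *body) {
    if (rec->body && rec->body != body)
      return false;
    rec->body = body;
    return true;
  }

  void skipSpaces() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }

  std::string parseKeyword() {
    skipSpaces();
    std::size_t start = pos;
    while (pos < text.size() &&
           (std::isalnum(static_cast<unsigned char>(text[pos])) ||
            text[pos] == '_'))
      ++pos;
    return text.substr(start, pos - start);
  }

  bool consume(char c) {
    skipSpaces();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }

  std::string text;
  std::size_t pos = 0;
  Node leaf{true, "", nullptr};
  std::map<std::string, std::unique_ptr<Node>> recs;
  std::vector<Node *> stack;
};
```

For example, parsing `test_rec<c, test_rec<c>>` yields a single node whose body points back to itself.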

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 47aa86c45cc6..cdbf974679bd 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
@@ -137,19 +138,73 @@ TestDialect::TestDialect(MLIRContext *context)
       >();
   addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
                 TestInlinerInterface>();
-  addTypes<TestType>();
+  addTypes<TestType, TestRecursiveType>();
   allowUnknownOperations();
 }
 
-Type TestDialect::parseType(DialectAsmParser &parser) const {
-  if (failed(parser.parseKeyword("test_type")))
+static Type parseTestType(DialectAsmParser &parser,
+                          llvm::SetVector<Type> &stack) {
+  StringRef typeTag;
+  if (failed(parser.parseKeyword(&typeTag)))
+    return Type();
+
+  if (typeTag == "test_type")
+    return TestType::get(parser.getBuilder().getContext());
+
+  if (typeTag != "test_rec")
+    return Type();
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
+
+  // If this type already has been parsed above in the stack, expect just the
+  // name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  stack.insert(rec);
+  Type subtype = parseTestType(parser, stack);
+  stack.pop_back();
+  if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
     return Type();
-  return TestType::get(getContext());
+
+  return rec;
+}
+
+Type TestDialect::parseType(DialectAsmParser &parser) const {
+  llvm::SetVector<Type> stack;
+  return parseTestType(parser, stack);
+}
+
+static void printTestType(Type type, DialectAsmPrinter &printer,
+                          llvm::SetVector<Type> &stack) {
+  if (type.isa<TestType>()) {
+    printer << "test_type";
+    return;
+  }
+
+  auto rec = type.cast<TestRecursiveType>();
+  printer << "test_rec<" << rec.getName();
+  if (!stack.contains(rec)) {
+    printer << ", ";
+    stack.insert(rec);
+    printTestType(rec.getBody(), printer, stack);
+    stack.pop_back();
+  }
+  printer << ">";
 }
 
 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  assert(type.isa<TestType>() && "unexpected type");
-  printer << "test_type";
+  llvm::SetVector<Type> stack;
+  printTestType(type, printer, stack);
 }
 
 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,

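Printing mirrors parsing: `printTestType` emits only the name of a recursive type that is already on the stack, which is what lets self-referring types print in finite space. A standalone sketch under the same assumptions (hypothetical `Node` struct, `std::vector` as the stack, a `std::string` in place of `DialectAsmPrinter`):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the test types.
struct Node {
  bool isLeaf;
  std::string name;
  const Node *body = nullptr;
};

// A recursive type already on the stack is printed by name only, which
// terminates the recursion on cycles.
static void printNode(const Node *node, std::string &out,
                      std::vector<const Node *> &stack) {
  if (node->isLeaf) {
    out += "test_type";
    return;
  }
  out += "test_rec<" + node->name;
  if (std::find(stack.begin(), stack.end(), node) == stack.end()) {
    assert(node->body && "body must be set before printing");
    out += ", ";
    stack.push_back(node);
    printNode(node->body, out, stack);
    stack.pop_back();
  }
  out += ">";
}

std::string printType(const Node *node) {
  std::string out;
  std::vector<const Node *> stack;
  printNode(node, out, stack);
  return out;
}
```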
diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 0596f61c1fa1..9e2c297c6a89 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -39,6 +39,60 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
     emitRemark(loc) << *this << " - TestC";
   }
 };
+
+/// Storage for simple named recursive types, where the type is identified by
+/// its name and can "contain" another type, including itself.
+struct TestRecursiveTypeStorage : public TypeStorage {
+  using KeyTy = StringRef;
+
+  explicit TestRecursiveTypeStorage(StringRef key) : name(key), body(Type()) {}
+
+  bool operator==(const KeyTy &other) const { return name == other; }
+
+  static TestRecursiveTypeStorage *construct(TypeStorageAllocator &allocator,
+                                             const KeyTy &key) {
+    return new (allocator.allocate<TestRecursiveTypeStorage>())
+        TestRecursiveTypeStorage(allocator.copyInto(key));
+  }
+
+  LogicalResult mutate(TypeStorageAllocator &allocator, Type newBody) {
+    // Cannot set a different body than before.
+    if (body && body != newBody)
+      return failure();
+
+    body = newBody;
+    return success();
+  }
+
+  StringRef name;
+  Type body;
+};
+
+/// Simple recursive type identified by its name and pointing to another named
+/// type, potentially itself. This requires the body to be mutated separately
+/// from type creation.
+class TestRecursiveType
+    : public Type::TypeBase<TestRecursiveType, Type, TestRecursiveTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) {
+    return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1;
+  }
+
+  static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
+    return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
+                     name);
+  }
+
+  /// Body getter and setter.
+  LogicalResult setBody(Type body) { return Base::mutate(body); }
+  Type getBody() { return getImpl()->body; }
+
+  /// Name/key getter.
+  StringRef getName() { return getImpl()->name; }
+};
+
 } // end namespace mlir
 
 #endif // MLIR_TESTTYPES_H

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 5456dc9e8816..f77b26e5ca18 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTestIR
   TestMatchers.cpp
   TestSideEffects.cpp
   TestSymbolUses.cpp
+  TestTypes.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp
new file mode 100644
index 000000000000..f62c06eededf
--- /dev/null
+++ b/mlir/test/lib/IR/TestTypes.cpp
@@ -0,0 +1,78 @@
+//===- TestTypes.cpp - Test passes for MLIR types -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTypes.h"
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestRecursiveTypesPass
+    : public PassWrapper<TestRecursiveTypesPass, FunctionPass> {
+  LogicalResult createIRWithTypes();
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+
+    // Just make sure recursive types are printed and parsed.
+    if (func.getName() == "roundtrip")
+      return;
+
+    // Create a recursive type and print it as a part of a dummy op.
+    if (func.getName() == "create") {
+      if (failed(createIRWithTypes()))
+        signalPassFailure();
+      return;
+    }
+
+    // Unknown key.
+    func.emitOpError() << "unexpected function name";
+    signalPassFailure();
+  }
+};
+} // end namespace
+
+LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
+  MLIRContext *ctx = &getContext();
+  FuncOp func = getFunction();
+  auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+  if (failed(type.setBody(type)))
+    return func.emitError("expected to be able to set the type body");
+
+  // Setting the same body is fine.
+  if (failed(type.setBody(type)))
+    return func.emitError(
+        "expected to be able to set the type body to the same value");
+
+  // Setting a different body is not.
+  if (succeeded(type.setBody(IndexType::get(ctx))))
+    return func.emitError(
+        "not expected to be able to change the type body more than once");
+
+  // Expecting to get the same type for the same name.
+  auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+  if (type != other)
+    return func.emitError("expected type name to be the uniquing key");
+
+  // Create the op to check how the type is printed.
+  OperationState state(func.getLoc(), "test.dummy_type_test_op");
+  state.addTypes(type);
+  func.getBody().front().push_front(Operation::create(state));
+
+  return success();
+}
+
+namespace mlir {
+
+void registerTestRecursiveTypesPass() {
+  PassRegistration<TestRecursiveTypesPass> reg(
+      "test-recursive-types", "Test support for recursive types");
+}
+
+} // end namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f749c7ad98ad..f60864a6a371 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -63,6 +63,7 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestOpaqueLoc();
 void registerTestPreparationPassWithAllowedMemrefResults();
+void registerTestRecursiveTypesPass();
 void registerTestReducer();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestSCFUtilsPass();
@@ -138,6 +139,7 @@ void registerTestPasses() {
   registerTestMemRefStrideCalculation();
   registerTestOpaqueLoc();
   registerTestPreparationPassWithAllowedMemrefResults();
+  registerTestRecursiveTypesPass();
   registerTestReducer();
   registerTestGpuParallelLoopMappingPass();
   registerTestSCFUtilsPass();
