[Mlir-commits] [mlir] b121c26 - [mlir] Add helper method to print and parse cyclic attributes and types (#65210)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 4 09:19:21 PDT 2023


Author: Markus Böck
Date: 2023-09-04T18:19:18+02:00
New Revision: b121c266744d030120c59e6256559cbccacd3c6f

URL: https://github.com/llvm/llvm-project/commit/b121c266744d030120c59e6256559cbccacd3c6f
DIFF: https://github.com/llvm/llvm-project/commit/b121c266744d030120c59e6256559cbccacd3c6f.diff

LOG: [mlir] Add helper method to print and parse cyclic attributes and types (#65210)

Printing cyclic attributes and types currently has no first-class
support within the AsmPrinter and AsmParser. The workaround used by all
upstream mutable attributes and types has been to create a
`thread_local static SetVector` that tracks the attributes and types
currently being parsed or printed.

This solution is not ideal readability-wise due to its use of globals
and manual state tracking. Worst of all, the pattern had to be
reimplemented for every mutable attribute and type.

This patch therefore adds first-class support for this pattern to
`AsmPrinter` and `AsmParser`, replacing the hand-rolled versions. Calling
`tryStartCyclicPrint/Parse` registers the mutable attribute or type in an
internal stack. Until that region ends, any further call with the same
attribute or type returns failure. This way the nesting can be detected
and a short form printed or parsed instead. The region ends automatically
when the resetter returned by the call is destroyed, typically when the
calling function returns.

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/AsmParser/AsmParserImpl.h
    mlir/lib/AsmParser/ParserState.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/Dialect/LLVMIR/types-invalid.mlir
    mlir/test/lib/Dialect/Test/TestDialect.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 2131fe313f8c597..f894ee64a27b0cf 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -222,11 +222,69 @@ class AsmPrinter {
     printArrowTypeList(results);
   }
 
+  /// Class used to automatically end a cyclic region on destruction.
+  class CyclicPrintReset {
+  public:
+    explicit CyclicPrintReset(AsmPrinter *printer) : printer(printer) {}
+    /// Ends the region unless ownership was moved to another reset.
+    ~CyclicPrintReset() {
+      if (printer)
+        printer->popCyclicPrinting();
+    }
+    /// Non-copyable: every started region must be ended exactly once.
+    CyclicPrintReset(const CyclicPrintReset &) = delete;
+    CyclicPrintReset &operator=(const CyclicPrintReset &) = delete;
+    /// Moving transfers the obligation to end the region; `rhs` is emptied.
+    CyclicPrintReset(CyclicPrintReset &&rhs) noexcept
+        : printer(std::exchange(rhs.printer, nullptr)) {}
+    /// Swaps so a region this object already owns is ended by `rhs`'s
+    /// destructor, keeping pushes and pops balanced instead of leaking it.
+    CyclicPrintReset &operator=(CyclicPrintReset &&rhs) noexcept {
+      std::swap(printer, rhs.printer);
+      return *this;
+    }
+
+  private:
+    AsmPrinter *printer;
+  };
+
+  /// Attempts to start a cyclic printing region for `attrOrType`.
+  /// A cyclic printing region starts with this call and ends with the
+  /// destruction of the returned `CyclicPrintReset`. During this time,
+  /// calling `tryStartCyclicPrint` with the same attribute or type on a
+  /// printer sharing this printer's state returns failure.
+  ///
+  /// This makes it possible to break infinite recursion when printing cyclic
+  /// attributes or types: when nested within itself, print only the
+  /// immutable parameters.
+  template <class AttrOrTypeT>
+  FailureOr<CyclicPrintReset> tryStartCyclicPrint(AttrOrTypeT attrOrType) {
+    constexpr bool isMutable =
+        std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>,
+                          AttrOrTypeT> ||
+        std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
+    static_assert(isMutable, "Only mutable attributes or types can be cyclic");
+    if (succeeded(pushCyclicPrinting(attrOrType.getAsOpaquePointer())))
+      return CyclicPrintReset(this);
+    return failure();
+  }
+
 protected:
   /// Initialize the printer with no internal implementation. In this case, all
   /// virtual methods of this class must be overriden.
   AsmPrinter() = default;
 
+  /// Pushes a new attribute or type in the form of a type erased pointer
+  /// into an internal set.
+  /// Returns success if the type or attribute was inserted in the set or
+  /// failure if it was already contained.
+  virtual LogicalResult pushCyclicPrinting(const void *opaquePointer);
+
+  /// Removes the element that was last inserted with a successful call to
+  /// `pushCyclicPrinting`. There must be exactly one `popCyclicPrinting` call
+  /// in reverse order of all successful `pushCyclicPrinting`.
+  virtual void popCyclicPrinting();
+
 private:
   AsmPrinter(const AsmPrinter &) = delete;
   void operator=(const AsmPrinter &) = delete;
@@ -1265,12 +1323,67 @@ class AsmParser {
   /// next token.
   virtual ParseResult parseXInDimensionList() = 0;
 
+  /// Class used to automatically end a cyclic region on destruction.
+  class CyclicParseReset {
+  public:
+    explicit CyclicParseReset(AsmParser *parser) : parser(parser) {}
+    /// Ends the region unless ownership was moved to another reset.
+    ~CyclicParseReset() {
+      if (parser)
+        parser->popCyclicParsing();
+    }
+    /// Non-copyable; move-assign swaps so an owned region is never leaked.
+    CyclicParseReset(const CyclicParseReset &) = delete;
+    CyclicParseReset &operator=(const CyclicParseReset &) = delete;
+    CyclicParseReset(CyclicParseReset &&rhs) noexcept
+        : parser(std::exchange(rhs.parser, nullptr)) {}
+    CyclicParseReset &operator=(CyclicParseReset &&rhs) noexcept {
+      std::swap(parser, rhs.parser);
+      return *this;
+    }
+
+  private:
+    AsmParser *parser;
+  };
+
+  /// Attempts to start a cyclic parsing region for `attrOrType`.
+  /// A cyclic parsing region starts with this call and ends with the
+  /// destruction of the returned `CyclicParseReset`. During this time,
+  /// calling `tryStartCyclicParse` with the same attribute or type on a
+  /// parser sharing this parser's state returns failure.
+  ///
+  /// This makes it possible to parse cyclic attributes or types by parsing a
+  /// short form if nested within itself.
+  template <class AttrOrTypeT>
+  FailureOr<CyclicParseReset> tryStartCyclicParse(AttrOrTypeT attrOrType) {
+    constexpr bool isMutable =
+        std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>,
+                          AttrOrTypeT> ||
+        std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
+    static_assert(isMutable, "Only mutable attributes or types can be cyclic");
+    // On success, the returned reset pops the entry again when destroyed.
+    if (succeeded(pushCyclicParsing(attrOrType.getAsOpaquePointer())))
+      return CyclicParseReset(this);
+    return failure();
+  }
+
 protected:
   /// Parse a handle to a resource within the assembly format for the given
   /// dialect.
   virtual FailureOr<AsmDialectResourceHandle>
   parseResourceHandle(Dialect *dialect) = 0;
 
+  /// Pushes a new attribute or type in the form of a type erased pointer
+  /// into an internal set.
+  /// Returns success if the type or attribute was inserted in the set or
+  /// failure if it was already contained.
+  virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;
+
+  /// Removes the element that was last inserted with a successful call to
+  /// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call
+  /// in reverse order of all successful `pushCyclicParsing`.
+  virtual void popCyclicParsing() = 0;
+
   //===--------------------------------------------------------------------===//
   // Code Completion
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 7208198f89e22c1..30c0079cda08611 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,6 +570,14 @@ class AsmParserImpl : public BaseT {
     return parser.parseXInDimensionList();
   }
 
+  LogicalResult pushCyclicParsing(const void *opaquePointer) override {
+    // Fails when `opaquePointer` is already on the stack, i.e. on a cycle.
+    return success(parser.getState().cyclicParsingStack.insert(opaquePointer));
+  }
+  void popCyclicParsing() override {
+    parser.getState().cyclicParsingStack.pop_back(); // Ends latest region.
+  }
+
   //===--------------------------------------------------------------------===//
   // Code Completion
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 0bf9bffaf1b6009..1428ea3a82cee9f 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -12,6 +12,7 @@
 #include "Lexer.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
 
 namespace mlir {
@@ -70,6 +71,10 @@ struct ParserState {
   /// The current state for symbol parsing.
   SymbolState &symbols;
 
+  /// Stack of potentially cyclic mutable attributes or type currently being
+  /// parsed.
+  SetVector<const void *> cyclicParsingStack;
+
   /// An optional pointer to a struct containing high level parser state to be
   /// populated during parsing.
   AsmParserState *asmState;

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index afb8c9060619185..9810d4d64367739 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -54,27 +54,16 @@ static StringRef getTypeKeyword(Type type) {
 /// Prints a structure type. Keeps track of known struct names to handle self-
 /// or mutually-referring structs without falling into infinite recursion.
 static void printStructType(AsmPrinter &printer, LLVMStructType type) {
-  // This keeps track of the names of identified structure types that are
-  // currently being printed. Since such types can refer themselves, this
-  // tracking is necessary to stop the recursion: the current function may be
-  // called recursively from AsmPrinter::printType after the appropriate
-  // dispatch. We maintain the invariant of this storage being modified
-  // exclusively in this function, and at most one name being added per call.
-  // TODO: consider having such functionality inside AsmPrinter.
-  thread_local SetVector<StringRef> knownStructNames;
-  unsigned stackSize = knownStructNames.size();
-  (void)stackSize;
-  auto guard = llvm::make_scope_exit([&]() {
-    assert(knownStructNames.size() == stackSize &&
-           "malformed identified stack when printing recursive structs");
-  });
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
 
   printer << "<";
   if (type.isIdentified()) {
+    cyclicPrint = printer.tryStartCyclicPrint(type);
+
     printer << '"' << type.getName() << '"';
     // If we are printing a reference to one of the enclosing structs, just
     // print the name and stop to avoid infinitely long output.
-    if (knownStructNames.count(type.getName())) {
+    if (failed(cyclicPrint)) {
       printer << '>';
       return;
     }
@@ -91,12 +80,8 @@ static void printStructType(AsmPrinter &printer, LLVMStructType type) {
 
   // Put the current type on stack to avoid infinite recursion.
   printer << '(';
-  if (type.isIdentified())
-    knownStructNames.insert(type.getName());
   llvm::interleaveComma(type.getBody(), printer.getStream(),
                         [&](Type subtype) { dispatchPrint(printer, subtype); });
-  if (type.isIdentified())
-    knownStructNames.pop_back();
   printer << ')';
   printer << '>';
 }
@@ -198,21 +183,6 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
 ///               | `struct<` string-literal `>`
 ///               | `struct<` string-literal `, opaque>`
 static LLVMStructType parseStructType(AsmParser &parser) {
-  // This keeps track of the names of identified structure types that are
-  // currently being parsed. Since such types can refer themselves, this
-  // tracking is necessary to stop the recursion: the current function may be
-  // called recursively from AsmParser::parseType after the appropriate
-  // dispatch. We maintain the invariant of this storage being modified
-  // exclusively in this function, and at most one name being added per call.
-  // TODO: consider having such functionality inside AsmParser.
-  thread_local SetVector<StringRef> knownStructNames;
-  unsigned stackSize = knownStructNames.size();
-  (void)stackSize;
-  auto guard = llvm::make_scope_exit([&]() {
-    assert(knownStructNames.size() == stackSize &&
-           "malformed identified stack when parsing recursive structs");
-  });
-
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
 
   if (failed(parser.parseLess()))
@@ -224,11 +194,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
   std::string name;
   bool isIdentified = succeeded(parser.parseOptionalString(&name));
   if (isIdentified) {
-    if (knownStructNames.count(name)) {
-      if (failed(parser.parseGreater()))
-        return LLVMStructType();
-      return LLVMStructType::getIdentifiedChecked(
+    SMLoc greaterLoc = parser.getCurrentLocation();
+    if (succeeded(parser.parseOptionalGreater())) {
+      auto type = LLVMStructType::getIdentifiedChecked(
           [loc] { return emitError(loc); }, loc.getContext(), name);
+      if (succeeded(parser.tryStartCyclicParse(type))) {
+        parser.emitError(
+            greaterLoc,
+            "struct without a body only allowed in a recursive struct");
+        return nullptr;
+      }
+
+      return type;
     }
     if (failed(parser.parseComma()))
       return LLVMStructType();
@@ -251,6 +228,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
     return type;
   }
 
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse;
+  if (isIdentified) {
+    cyclicParse =
+        parser.tryStartCyclicParse(LLVMStructType::getIdentifiedChecked(
+            [loc] { return emitError(loc); }, loc.getContext(), name));
+    if (failed(cyclicParse)) {
+      parser.emitError(kwLoc,
+                       "identifier already used for an enclosing struct");
+      return nullptr;
+    }
+  }
+
   // Check for packedness.
   bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
   if (failed(parser.parseLParen()))
@@ -273,14 +262,10 @@ static LLVMStructType parseStructType(AsmParser &parser) {
   SmallVector<Type, 4> subtypes;
   SMLoc subtypesLoc = parser.getCurrentLocation();
   do {
-    if (isIdentified)
-      knownStructNames.insert(name);
     Type type;
     if (dispatchParse(parser, type))
       return LLVMStructType();
     subtypes.push_back(type);
-    if (isIdentified)
-      knownStructNames.pop_back();
   } while (succeeded(parser.parseOptionalComma()));
 
   if (parser.parseRParen() || parser.parseGreater())

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 124d4ed6e8e6edc..76e703946428361 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -710,33 +710,20 @@ static Type parseStructType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
   // TODO: This function is quite lengthy. Break it down into smaller chunks.
 
-  // To properly resolve recursive references while parsing recursive struct
-  // types, we need to maintain a list of enclosing struct type names. This set
-  // maintains the names of struct types in which the type we are about to parse
-  // is nested.
-  //
-  // Note: This has to be thread_local to enable multiple threads to safely
-  // parse concurrently.
-  thread_local SetVector<StringRef> structContext;
-
-  static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
-                                           StringRef identifier) {
-    if (!identifier.empty())
-      structContext.remove(identifier);
-
-    return Type();
-  };
-
   if (parser.parseLess())
     return Type();
 
   StringRef identifier;
+  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
 
   // Check if this is an identified struct type.
   if (succeeded(parser.parseOptionalKeyword(&identifier))) {
     // Check if this is a possible recursive reference.
+    auto structType =
+        StructType::getIdentified(dialect.getContext(), identifier);
+    cyclicParse = parser.tryStartCyclicParse(structType);
     if (succeeded(parser.parseOptionalGreater())) {
-      if (structContext.count(identifier) == 0) {
+      if (succeeded(cyclicParse)) {
         parser.emitError(
             parser.getNameLoc(),
             "recursive struct reference not nested in struct definition");
@@ -744,30 +731,24 @@ static Type parseStructType(SPIRVDialect const &dialect,
         return Type();
       }
 
-      return StructType::getIdentified(dialect.getContext(), identifier);
+      return structType;
     }
 
     if (failed(parser.parseComma()))
       return Type();
 
-    if (structContext.count(identifier) != 0) {
+    if (failed(cyclicParse)) {
       parser.emitError(parser.getNameLoc(),
                        "identifier already used for an enclosing struct");
-
-      return removeIdentifierAndFail(structContext, identifier);
+      return Type();
     }
-
-    structContext.insert(identifier);
   }
 
   if (failed(parser.parseLParen()))
-    return removeIdentifierAndFail(structContext, identifier);
+    return Type();
 
   if (succeeded(parser.parseOptionalRParen()) &&
       succeeded(parser.parseOptionalGreater())) {
-    if (!identifier.empty())
-      structContext.remove(identifier);
-
     return StructType::getEmpty(dialect.getContext(), identifier);
   }
 
@@ -783,30 +764,28 @@ static Type parseStructType(SPIRVDialect const &dialect,
   do {
     Type memberType;
     if (parser.parseType(memberType))
-      return removeIdentifierAndFail(structContext, identifier);
+      return Type();
     memberTypes.push_back(memberType);
 
     if (succeeded(parser.parseOptionalLSquare()))
       if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
                                        memberDecorationInfo))
-        return removeIdentifierAndFail(structContext, identifier);
+        return Type();
   } while (succeeded(parser.parseOptionalComma()));
 
   if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
     parser.emitError(parser.getNameLoc(),
                      "offset specification must be given for all members");
-    return removeIdentifierAndFail(structContext, identifier);
+    return Type();
   }
 
   if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
-    return removeIdentifierAndFail(structContext, identifier);
+    return Type();
 
   if (!identifier.empty()) {
     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
                                      memberDecorationInfo)))
       return Type();
-
-    structContext.remove(identifier);
     return idStructTy;
   }
 
@@ -886,20 +865,20 @@ static void print(SampledImageType type, DialectAsmPrinter &os) {
 }
 
 static void print(StructType type, DialectAsmPrinter &os) {
-  thread_local SetVector<StringRef> structContext;
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
 
   os << "struct<";
 
   if (type.isIdentified()) {
     os << type.getIdentifier();
 
-    if (structContext.count(type.getIdentifier())) {
+    cyclicPrint = os.tryStartCyclicPrint(type);
+    if (failed(cyclicPrint)) {
       os << ">";
       return;
     }
 
     os << ", ";
-    structContext.insert(type.getIdentifier());
   }
 
   os << "(";
@@ -928,9 +907,6 @@ static void print(StructType type, DialectAsmPrinter &os) {
   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                         printMember);
   os << ")>";
-
-  if (type.isIdentified())
-    structContext.remove(type.getIdentifier());
 }
 
 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 333f4e537fcc749..c662edd592036ce 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -406,6 +406,10 @@ class AsmPrinter::Impl {
   void printAffineConstraint(AffineExpr expr, bool isEq);
   void printIntegerSet(IntegerSet set);
 
+  LogicalResult pushCyclicPrinting(const void *opaquePointer);
+
+  void popCyclicPrinting();
+
 protected:
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<StringRef> elidedAttrs = {},
@@ -918,6 +922,16 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
   void printSymbolName(StringRef) override {}
   void printResourceHandle(const AsmDialectResourceHandle &) override {}
 
+  LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
+    return success(cyclicPrintingStack.insert(opaquePointer));
+  }
+  /// Ends the most recently started cyclic printing region.
+  void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }
+
+  /// Stack of potentially cyclic mutable attributes or types currently being
+  /// printed. Owned by this printer rather than by an AsmState.
+  SetVector<const void *> cyclicPrintingStack;
+
   /// The initializer to use when identifying aliases.
   AliasInitializer &initializer;
 
@@ -1791,6 +1805,12 @@ class AsmStateImpl {
     return dialectResources;
   }
 
+  LogicalResult pushCyclicPrinting(const void *opaquePointer) {
+    return success(cyclicPrintingStack.insert(opaquePointer));
+  }
+  /// Ends the most recently started cyclic printing region.
+  void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }
+
 private:
   /// Collection of OpAsm interfaces implemented in the context.
   DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
@@ -1816,6 +1836,10 @@ class AsmStateImpl {
   /// An optional location map to be populated.
   AsmState::LocationMap *locationMap;
 
+  /// Stack of potentially cyclic mutable attributes or type currently being
+  /// printed.
+  SetVector<const void *> cyclicPrintingStack;
+
   // Allow direct access to the impl fields.
   friend AsmState;
 };
@@ -2689,6 +2713,12 @@ void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
   printHexString(StringRef(data.data(), data.size()));
 }
 
+LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
+  return state.pushCyclicPrinting(opaquePointer);
+}
+// Both cyclic-printing hooks forward to `state`, which owns the stack.
+void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
+
 //===--------------------------------------------------------------------===//
 // AsmPrinter
 //===--------------------------------------------------------------------===//
@@ -2747,6 +2777,12 @@ void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
   impl->printResourceHandle(resource);
 }
 
+LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
+  return impl->pushCyclicPrinting(opaquePointer);
+}
+// Default hooks forward to `impl`; printers without one must override them.
+void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
+
 //===----------------------------------------------------------------------===//
 // Affine expressions and maps
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
index f06f056cf49047f..76fb6780d8668fd 100644
--- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
@@ -68,6 +68,20 @@ func.func @struct_literal_opaque() {
 
 // -----
 
+func.func @top_level_struct_no_body() {
+  // expected-error @below {{struct without a body only allowed in a recursive struct}}
+  "some.op"() : () -> !llvm.struct<"a">
+}
+
+// -----
+
+func.func @nested_redefine_attempt() {
+  // expected-error @below {{identifier already used for an enclosing struct}}
+  "some.op"() : () -> !llvm.struct<"a", (struct<"a", ()>)>
+}
+
+// -----
+
 func.func @unexpected_type() {
   // expected-error @+1 {{unexpected type, expected keyword}}
   "some.op"() : () -> !llvm.tensor<*xf32>

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index e41c6040fe57a9b..8524d5b14584472 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -46,11 +46,6 @@ def Test_Dialect : Dialect {
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;
-
-    ::mlir::Type parseTestType(::mlir::AsmParser &parser,
-                               ::llvm::SetVector<::mlir::Type> &stack) const;
-    void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
-                       ::llvm::SetVector<::mlir::Type> &stack) const;
   }];
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 20dc03a76526978..abb35d71d7f6d58 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -404,8 +404,7 @@ void TestDialect::registerTypes() {
   registerDynamicType(getCustomAssemblyFormatDynamicType(this));
 }
 
-Type TestDialect::parseTestType(AsmParser &parser,
-                                SetVector<Type> &stack) const {
+Type TestDialect::parseType(DialectAsmParser &parser) const {
   StringRef typeTag;
   {
     Type genType;
@@ -434,9 +433,12 @@ Type TestDialect::parseTestType(AsmParser &parser,
     return Type();
   auto rec = TestRecursiveType::get(parser.getContext(), name);
 
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
+      parser.tryStartCyclicParse(rec);
+
   // If this type already has been parsed above in the stack, expect just the
   // name.
-  if (stack.contains(rec)) {
+  if (failed(cyclicParse)) {
     if (failed(parser.parseGreater()))
       return Type();
     return rec;
@@ -445,22 +447,14 @@ Type TestDialect::parseTestType(AsmParser &parser,
   // Otherwise, parse the body and update the type.
   if (failed(parser.parseComma()))
     return Type();
-  stack.insert(rec);
-  Type subtype = parseTestType(parser, stack);
-  stack.pop_back();
+  Type subtype = parseType(parser);
   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
     return Type();
 
   return rec;
 }
 
-Type TestDialect::parseType(DialectAsmParser &parser) const {
-  SetVector<Type> stack;
-  return parseTestType(parser, stack);
-}
-
-void TestDialect::printTestType(Type type, AsmPrinter &printer,
-                                SetVector<Type> &stack) const {
+void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
   if (succeeded(generatedTypePrinter(type, printer)))
     return;
 
@@ -468,21 +462,18 @@ void TestDialect::printTestType(Type type, AsmPrinter &printer,
     return;
 
   auto rec = llvm::cast<TestRecursiveType>(type);
+
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
+      printer.tryStartCyclicPrint(rec);
+
   printer << "test_rec<" << rec.getName();
-  if (!stack.contains(rec)) {
+  if (succeeded(cyclicPrint)) {
     printer << ", ";
-    stack.insert(rec);
-    printTestType(rec.getBody(), printer, stack);
-    stack.pop_back();
+    printType(rec.getBody(), printer);
   }
   printer << ">";
 }
 
-void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  SetVector<Type> stack;
-  printTestType(type, printer, stack);
-}
-
 Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
 
 void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
@@ -490,16 +481,17 @@ void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
 StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
 
 Type TestRecursiveAliasType::parse(AsmParser &parser) {
-  thread_local static SetVector<Type> stack;
-
   StringRef name;
   if (parser.parseLess() || parser.parseKeyword(&name))
     return Type();
   auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
 
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
+      parser.tryStartCyclicParse(rec);
+
   // If this type already has been parsed above in the stack, expect just the
   // name.
-  if (stack.contains(rec)) {
+  if (failed(cyclicParse)) {
     if (failed(parser.parseGreater()))
       return Type();
     return rec;
@@ -508,11 +500,9 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
   // Otherwise, parse the body and update the type.
   if (failed(parser.parseComma()))
     return Type();
-  stack.insert(rec);
   Type subtype;
   if (parser.parseType(subtype))
     return nullptr;
-  stack.pop_back();
   if (!subtype || failed(parser.parseGreater()))
     return Type();
 
@@ -522,14 +512,14 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
 }
 
 void TestRecursiveAliasType::print(AsmPrinter &printer) const {
-  thread_local static SetVector<Type> stack;
+  // Print the body only at the outermost nesting level to break the cycle.
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
+      printer.tryStartCyclicPrint(*this);
 
   printer << "<" << getName();
-  if (!stack.contains(*this)) {
+  if (succeeded(cyclicPrint)) {
     printer << ", ";
-    stack.insert(*this);
     printer << getBody();
-    stack.pop_back();
   }
   printer << ">";
 }


        


More information about the Mlir-commits mailing list