[llvm-branch-commits] [mlir] 74438ef - [mlir] Use thread_local stack in LLVM dialect type parsing and printing

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 6 03:10:08 PST 2021


Author: Alex Zinenko
Date: 2021-01-06T12:05:24+01:00
New Revision: 74438eff511e71dc33841546d89cb34206551d55

URL: https://github.com/llvm/llvm-project/commit/74438eff511e71dc33841546d89cb34206551d55
DIFF: https://github.com/llvm/llvm-project/commit/74438eff511e71dc33841546d89cb34206551d55.diff

LOG: [mlir] Use thread_local stack in LLVM dialect type parsing and printing

LLVM dialect type parsing and printing have been using a local stack object
forwarded between recursive functions responsible for parsing or printing
specific types. This stack is necessary to intercept (mutually) recursive
structure types and avoid infinite recursion. This approach works only thanks
to the closedness of the LLVM dialect type system: types that don't belong to
the dialect are not allowed. Switch the approach to using a `thread_local`
stack inside the functions parsing the structure types. This makes the code
slightly cleaner by avoiding the need to pass the stack object around and, more
importantly, makes it possible to reconsider the closedness of the LLVM dialect
type system. As a nice side effect of this change, container LLVM dialect types
now support type aliases in their body (although it is currently impossible to
also use the aliases when printing).

Depends On D93713

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93714

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/test/Dialect/LLVMIR/types-invalid.mlir
    mlir/test/Dialect/LLVMIR/types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 3d72e254f338..08c00befcf18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -19,8 +20,14 @@ using namespace mlir::LLVM;
 // Printing.
 //===----------------------------------------------------------------------===//
 
-static void printTypeImpl(llvm::raw_ostream &os, Type type,
-                          llvm::SetVector<StringRef> &stack);
+/// If the given type is compatible with the LLVM dialect, prints it using
+/// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
+/// prints it as usual.
+static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
+  if (isCompatibleType(type))
+    return mlir::LLVM::detail::printType(type, printer);
+  printer.printType(type);
+}
 
 /// Returns the keyword to use for the given type.
 static StringRef getTypeKeyword(Type type) {
@@ -48,76 +55,79 @@ static StringRef getTypeKeyword(Type type) {
       });
 }
 
-/// Prints the body of a structure type. Uses `stack` to avoid printing
-/// recursive structs indefinitely.
-static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
-                                llvm::SetVector<StringRef> &stack) {
-  if (type.isIdentified() && type.isOpaque()) {
-    os << "opaque";
-    return;
-  }
-
-  if (type.isPacked())
-    os << "packed ";
-
-  // Put the current type on stack to avoid infinite recursion.
-  os << '(';
-  if (type.isIdentified())
-    stack.insert(type.getName());
-  llvm::interleaveComma(type.getBody(), os, [&](Type subtype) {
-    printTypeImpl(os, subtype, stack);
+/// Prints a structure type. Keeps track of known struct names to handle self-
+/// or mutually-referring structs without falling into infinite recursion.
+static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
+  // This keeps track of the names of identified structure types that are
+  // currently being printed. Since such types can refer themselves, this
+  // tracking is necessary to stop the recursion: the current function may be
+  // called recursively from DialectAsmPrinter::printType after the appropriate
+  // dispatch. We maintain the invariant of this storage being modified
+  // exclusively in this function, and at most one name being added per call.
+  // TODO: consider having such functionality inside DialectAsmPrinter.
+  thread_local llvm::SetVector<StringRef> knownStructNames;
+  unsigned stackSize = knownStructNames.size();
+  (void)stackSize;
+  auto guard = llvm::make_scope_exit([&]() {
+    assert(knownStructNames.size() == stackSize &&
+           "malformed identified stack when printing recursive structs");
   });
-  if (type.isIdentified())
-    stack.pop_back();
-  os << ')';
-}
 
-/// Prints a structure type. Uses `stack` to keep track of the identifiers of
-/// the structs being printed. Checks if the identifier of a struct is contained
-/// in `stack`, i.e. whether a self-reference to a recursive stack is being
-/// printed, and only prints the name to avoid infinite recursion.
-static void printStructType(llvm::raw_ostream &os, LLVMStructType type,
-                            llvm::SetVector<StringRef> &stack) {
-  os << "<";
+  printer << "<";
   if (type.isIdentified()) {
-    os << '"' << type.getName() << '"';
+    printer << '"' << type.getName() << '"';
     // If we are printing a reference to one of the enclosing structs, just
     // print the name and stop to avoid infinitely long output.
-    if (stack.count(type.getName())) {
-      os << '>';
+    if (knownStructNames.count(type.getName())) {
+      printer << '>';
       return;
     }
-    os << ", ";
+    printer << ", ";
+  }
+
+  if (type.isIdentified() && type.isOpaque()) {
+    printer << "opaque>";
+    return;
   }
 
-  printStructTypeBody(os, type, stack);
-  os << '>';
+  if (type.isPacked())
+    printer << "packed ";
+
+  // Put the current type on stack to avoid infinite recursion.
+  printer << '(';
+  if (type.isIdentified())
+    knownStructNames.insert(type.getName());
+  llvm::interleaveComma(type.getBody(), printer.getStream(),
+                        [&](Type subtype) { dispatchPrint(printer, subtype); });
+  if (type.isIdentified())
+    knownStructNames.pop_back();
+  printer << ')';
+  printer << '>';
 }
 
 /// Prints a type containing a fixed number of elements.
 template <typename TypeTy>
-static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type,
-                                   llvm::SetVector<StringRef> &stack) {
-  os << '<' << type.getNumElements() << " x ";
-  printTypeImpl(os, type.getElementType(), stack);
-  os << '>';
+static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
+  printer << '<' << type.getNumElements() << " x ";
+  dispatchPrint(printer, type.getElementType());
+  printer << '>';
 }
 
 /// Prints a function type.
-static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
-                              llvm::SetVector<StringRef> &stack) {
-  os << '<';
-  printTypeImpl(os, funcType.getReturnType(), stack);
-  os << " (";
-  llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) {
-    printTypeImpl(os, subtype, stack);
-  });
+static void printFunctionType(DialectAsmPrinter &printer,
+                              LLVMFunctionType funcType) {
+  printer << '<';
+  dispatchPrint(printer, funcType.getReturnType());
+  printer << " (";
+  llvm::interleaveComma(
+      funcType.getParams(), printer.getStream(),
+      [&printer](Type subtype) { dispatchPrint(printer, subtype); });
   if (funcType.isVarArg()) {
     if (funcType.getNumParams() != 0)
-      os << ", ";
-    os << "...";
+      printer << ", ";
+    printer << "...";
   }
-  os << ")>";
+  printer << ")>";
 }
 
 /// Prints the given LLVM dialect type recursively. This leverages closedness of
@@ -129,75 +139,59 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
 /// note that "b" is printed twice.
-static void printTypeImpl(llvm::raw_ostream &os, Type type,
-                          llvm::SetVector<StringRef> &stack) {
+void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
   if (!type) {
-    os << "<<NULL-TYPE>>";
+    printer << "<<NULL-TYPE>>";
     return;
   }
 
-  os << getTypeKeyword(type);
+  printer << getTypeKeyword(type);
 
   if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
-    os << intType.getBitWidth();
+    printer << intType.getBitWidth();
     return;
   }
 
   if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
-    os << '<';
-    printTypeImpl(os, ptrType.getElementType(), stack);
+    printer << '<';
+    dispatchPrint(printer, ptrType.getElementType());
     if (ptrType.getAddressSpace() != 0)
-      os << ", " << ptrType.getAddressSpace();
-    os << '>';
+      printer << ", " << ptrType.getAddressSpace();
+    printer << '>';
     return;
   }
 
   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
-    return printArrayOrVectorType(os, arrayType, stack);
+    return printArrayOrVectorType(printer, arrayType);
   if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
-    return printArrayOrVectorType(os, vectorType, stack);
+    return printArrayOrVectorType(printer, vectorType);
 
   if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
-    os << "<? x " << vectorType.getMinNumElements() << " x ";
-    printTypeImpl(os, vectorType.getElementType(), stack);
-    os << '>';
+    printer << "<? x " << vectorType.getMinNumElements() << " x ";
+    dispatchPrint(printer, vectorType.getElementType());
+    printer << '>';
     return;
   }
 
   if (auto structType = type.dyn_cast<LLVMStructType>())
-    return printStructType(os, structType, stack);
+    return printStructType(printer, structType);
 
   if (auto funcType = type.dyn_cast<LLVMFunctionType>())
-    return printFunctionType(os, funcType, stack);
-}
-
-void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
-  llvm::SetVector<StringRef> stack;
-  return printTypeImpl(printer.getStream(), type, stack);
+    return printFunctionType(printer, funcType);
 }
 
 //===----------------------------------------------------------------------===//
 // Parsing.
 //===----------------------------------------------------------------------===//
 
-static Type parseTypeImpl(DialectAsmParser &parser,
-                          llvm::SetVector<StringRef> &stack);
-
-/// Helper to be chained with other parsing functions.
-static ParseResult parseTypeImpl(DialectAsmParser &parser,
-                                 llvm::SetVector<StringRef> &stack,
-                                 Type &result) {
-  result = parseTypeImpl(parser, stack);
-  return success(result != nullptr);
-}
+static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
 
 /// Parses an LLVM dialect function type.
 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
-static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
-                                          llvm::SetVector<StringRef> &stack) {
+static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   Type returnType;
-  if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
+  if (parser.parseLess() || dispatchParse(parser, returnType) ||
       parser.parseLParen())
     return LLVMFunctionType();
 
@@ -219,9 +213,10 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
                                           /*isVarArg=*/true);
     }
 
-    argTypes.push_back(parseTypeImpl(parser, stack));
-    if (!argTypes.back())
+    Type arg;
+    if (dispatchParse(parser, arg))
       return LLVMFunctionType();
+    argTypes.push_back(arg);
   } while (succeeded(parser.parseOptionalComma()));
 
   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
@@ -232,11 +227,10 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
 
 /// Parses an LLVM dialect pointer type.
 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
-static LLVMPointerType parsePointerType(DialectAsmParser &parser,
-                                        llvm::SetVector<StringRef> &stack) {
+static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   Type elementType;
-  if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
+  if (parser.parseLess() || dispatchParse(parser, elementType))
     return LLVMPointerType();
 
   unsigned addressSpace = 0;
@@ -251,15 +245,14 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser,
 /// Parses an LLVM dialect vector type.
 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
 /// Supports both fixed and scalable vectors.
-static LLVMVectorType parseVectorType(DialectAsmParser &parser,
-                                      llvm::SetVector<StringRef> &stack) {
+static LLVMVectorType parseVectorType(DialectAsmParser &parser) {
   SmallVector<int64_t, 2> dims;
   llvm::SMLoc dimPos;
   Type elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
-      parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+      dispatchParse(parser, elementType) || parser.parseGreater())
     return LLVMVectorType();
 
   // We parsed a generic dimension list, but vectors only support two forms:
@@ -282,15 +275,14 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
 
 /// Parses an LLVM dialect array type.
 ///   llvm-type ::= `array<` integer `x` llvm-type `>`
-static LLVMArrayType parseArrayType(DialectAsmParser &parser,
-                                    llvm::SetVector<StringRef> &stack) {
+static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
   SmallVector<int64_t, 1> dims;
   llvm::SMLoc sizePos;
   Type elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
-      parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+      dispatchParse(parser, elementType) || parser.parseGreater())
     return LLVMArrayType();
 
   if (dims.size() != 1) {
@@ -302,13 +294,11 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
 }
 
 /// Attempts to set the body of an identified structure type. Reports a parsing
-/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
-/// types printed in the error message look like they did when parsed.
+/// error at `subtypesLoc` in case of failure.
 static LLVMStructType trySetStructBody(LLVMStructType type,
                                        ArrayRef<Type> subtypes, bool isPacked,
                                        DialectAsmParser &parser,
-                                       llvm::SMLoc subtypesLoc,
-                                       llvm::SetVector<StringRef> &stack) {
+                                       llvm::SMLoc subtypesLoc) {
   for (Type t : subtypes) {
     if (!LLVMStructType::isValidElementType(t)) {
       parser.emitError(subtypesLoc)
@@ -320,12 +310,8 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
   if (succeeded(type.setBody(subtypes, isPacked)))
     return type;
 
-  std::string currentBody;
-  llvm::raw_string_ostream currentBodyStream(currentBody);
-  printStructTypeBody(currentBodyStream, type, stack);
-  auto diag = parser.emitError(subtypesLoc)
-              << "identified type already used with a different body";
-  diag.attachNote() << "existing body: " << currentBodyStream.str();
+  parser.emitError(subtypesLoc)
+      << "identified type already used with a different body";
   return LLVMStructType();
 }
 
@@ -334,8 +320,22 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
 ///                 `(` llvm-type-list `)` `>`
 ///               | `struct<` string-literal `>`
 ///               | `struct<` string-literal `, opaque>`
-static LLVMStructType parseStructType(DialectAsmParser &parser,
-                                      llvm::SetVector<StringRef> &stack) {
+static LLVMStructType parseStructType(DialectAsmParser &parser) {
+  // This keeps track of the names of identified structure types that are
+  // currently being parsed. Since such types can refer themselves, this
+  // tracking is necessary to stop the recursion: the current function may be
+  // called recursively from DialectAsmParser::parseType after the appropriate
+  // dispatch. We maintain the invariant of this storage being modified
+  // exclusively in this function, and at most one name being added per call.
+  // TODO: consider having such functionality inside DialectAsmParser.
+  thread_local llvm::SetVector<StringRef> knownStructNames;
+  unsigned stackSize = knownStructNames.size();
+  (void)stackSize;
+  auto guard = llvm::make_scope_exit([&]() {
+    assert(knownStructNames.size() == stackSize &&
+           "malformed identified stack when parsing recursive structs");
+  });
+
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
 
   if (failed(parser.parseLess()))
@@ -347,7 +347,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
   StringRef name;
   bool isIdentified = succeeded(parser.parseOptionalString(&name));
   if (isIdentified) {
-    if (stack.count(name)) {
+    if (knownStructNames.count(name)) {
       if (failed(parser.parseGreater()))
         return LLVMStructType();
       return LLVMStructType::getIdentifiedChecked(loc, name);
@@ -384,7 +384,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
     if (!isIdentified)
       return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
     auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-    return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
+    return trySetStructBody(type, {}, isPacked, parser, kwLoc);
   }
 
   // Parse subtypes. For identified structs, put the identifier of the struct on
@@ -393,13 +393,13 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
   do {
     if (isIdentified)
-      stack.insert(name);
-    Type type = parseTypeImpl(parser, stack);
-    if (!type)
+      knownStructNames.insert(name);
+    Type type;
+    if (dispatchParse(parser, type))
       return LLVMStructType();
     subtypes.push_back(type);
     if (isIdentified)
-      stack.pop_back();
+      knownStructNames.pop_back();
   } while (succeeded(parser.parseOptionalComma()));
 
   if (parser.parseRParen() || parser.parseGreater())
@@ -409,30 +409,30 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
   if (!isIdentified)
     return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
   auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
+  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
 }
 
-/// Parses one of the LLVM dialect types.
-static Type parseTypeImpl(DialectAsmParser &parser,
-                          llvm::SetVector<StringRef> &stack) {
-  // Special case for integers (i[1-9][0-9]*) that are literals rather than
-  // keywords for the parser, so they are not caught by the main dispatch below.
-  // Try parsing it a built-in integer type instead.
-  Type maybeIntegerType;
-  MLIRContext *ctx = parser.getBuilder().getContext();
+/// Parses a type appearing inside another LLVM dialect-compatible type. This
+/// will try to parse any type in full form (including types with the `!llvm`
+/// prefix), and on failure fall back to parsing the short-hand version of the
+/// LLVM dialect types without the `!llvm` prefix.
+static Type dispatchParse(DialectAsmParser &parser) {
+  Type type;
   llvm::SMLoc keyLoc = parser.getCurrentLocation();
   Location loc = parser.getEncodedSourceLoc(keyLoc);
-  OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
-  if (result.hasValue()) {
-    if (failed(*result))
+  OptionalParseResult parseResult = parser.parseOptionalType(type);
+  if (parseResult.hasValue()) {
+    if (failed(*parseResult))
       return Type();
 
-    if (!maybeIntegerType.isSignlessInteger()) {
-      parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
-      return Type();
-    }
-    return LLVMIntegerType::getChecked(
-        loc, maybeIntegerType.getIntOrFloatBitWidth());
+    // Special case for integers (i[1-9][0-9]*) that are literals rather than
+    // keywords for the parser, so they are not caught by the main dispatch
+    // below. Try parsing it a built-in integer type instead.
+    auto intType = type.dyn_cast<IntegerType>();
+    if (!intType || !intType.isSignless())
+      return type;
+
+    return LLVMIntegerType::getChecked(loc, intType.getWidth());
   }
 
   // Dispatch to concrete functions.
@@ -440,6 +440,7 @@ static Type parseTypeImpl(DialectAsmParser &parser,
   if (failed(parser.parseKeyword(&key)))
     return Type();
 
+  MLIRContext *ctx = parser.getBuilder().getContext();
   return StringSwitch<function_ref<Type()>>(key)
       .Case("void", [&] { return LLVMVoidType::get(ctx); })
       .Case("half", [&] { return LLVMHalfType::get(ctx); })
@@ -453,18 +454,32 @@ static Type parseTypeImpl(DialectAsmParser &parser,
       .Case("token", [&] { return LLVMTokenType::get(ctx); })
       .Case("label", [&] { return LLVMLabelType::get(ctx); })
       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
-      .Case("func", [&] { return parseFunctionType(parser, stack); })
-      .Case("ptr", [&] { return parsePointerType(parser, stack); })
-      .Case("vec", [&] { return parseVectorType(parser, stack); })
-      .Case("array", [&] { return parseArrayType(parser, stack); })
-      .Case("struct", [&] { return parseStructType(parser, stack); })
+      .Case("func", [&] { return parseFunctionType(parser); })
+      .Case("ptr", [&] { return parsePointerType(parser); })
+      .Case("vec", [&] { return parseVectorType(parser); })
+      .Case("array", [&] { return parseArrayType(parser); })
+      .Case("struct", [&] { return parseStructType(parser); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();
       })();
 }
 
+/// Helper to use in parse lists.
+static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
+  type = dispatchParse(parser);
+  return success(type != nullptr);
+}
+
+/// Parses one of the LLVM dialect types.
 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
-  llvm::SetVector<StringRef> stack;
-  return parseTypeImpl(parser, stack);
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  Type type = dispatchParse(parser);
+  if (!type)
+    return type;
+  if (!isCompatibleType(type)) {
+    parser.emitError(loc) << "unexpected type, expected i* or keyword";
+    return nullptr;
+  }
+  return type;
 }

diff  --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
index 29da75aae584..3277e177bc9b 100644
--- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
@@ -30,8 +30,7 @@ func @void_pointer() {
 
 func @repeated_struct_name() {
   "some.op"() : () -> !llvm.struct<"a", (ptr<struct<"a">>)>
-  // expected-error @+2 {{identified type already used with a different body}}
-  // expected-note @+1 {{existing body: (ptr<struct<"a">>)}}
+  // expected-error @+1 {{identified type already used with a different body}}
   "some.op"() : () -> !llvm.struct<"a", (i32)>
 }
 
@@ -39,8 +38,7 @@ func @repeated_struct_name() {
 
 func @repeated_struct_name_packed() {
   "some.op"() : () -> !llvm.struct<"a", packed (i32)>
-  // expected-error @+2 {{identified type already used with a different body}}
-  // expected-note @+1 {{existing body: packed (i32)}}
+  // expected-error @+1 {{identified type already used with a different body}}
   "some.op"() : () -> !llvm.struct<"a", (i32)>
 }
 
@@ -48,8 +46,7 @@ func @repeated_struct_name_packed() {
 
 func @repeated_struct_opaque() {
   "some.op"() : () -> !llvm.struct<"a", opaque>
-  // expected-error @+2 {{identified type already used with a different body}}
-  // expected-note @+1 {{existing body: opaque}}
+  // expected-error @+1 {{identified type already used with a different body}}
   "some.op"() : () -> !llvm.struct<"a", ()>
 }
 
@@ -57,8 +54,7 @@ func @repeated_struct_opaque() {
 
 func @repeated_struct_opaque_non_empty() {
   "some.op"() : () -> !llvm.struct<"a", opaque>
-  // expected-error @+2 {{identified type already used with a different body}}
-  // expected-note @+1 {{existing body: opaque}}
+  // expected-error @+1 {{identified type already used with a different body}}
   "some.op"() : () -> !llvm.struct<"a", (i32, i32)>
 }
 
@@ -95,8 +91,7 @@ func @unexpected_type() {
 
 func @explicitly_opaque_struct() {
   "some.op"() : () -> !llvm.struct<"a", opaque>
-  // expected-error @+2 {{identified type already used with a different body}}
-  // expected-note @+1 {{existing body: opaque}}
+  // expected-error @+1 {{identified type already used with a different body}}
   "some.op"() : () -> !llvm.struct<"a", ()>
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
index bd24c68b8883..5258158efb69 100644
--- a/mlir/test/Dialect/LLVMIR/types.mlir
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -182,3 +182,29 @@ func @identified_struct() {
   return
 }
 
+func @verbose() {
+  // CHECK: !llvm.struct<(i64, struct<(float)>)>
+  "some.op"() : () -> !llvm.struct<(!llvm.i64, !llvm.struct<(!llvm.float)>)>
+  return
+}
+
+// -----
+
+// Check that type aliases can be used inside LLVM dialect types. Note that
+// currently they are _not_ printed back as this would require
+// DialectAsmPrinter to have a mechanism for querying the presence and
+// usability of an alias outside of its `printType` method.
+
+!baz = type !llvm.i64
+!qux = type !llvm.struct<(!baz)>
+
+!rec = type !llvm.struct<"a", (ptr<struct<"a">>)>
+
+// CHECK: aliases
+llvm.func @aliases() {
+  // CHECK: !llvm.struct<(i32, float, struct<(i64)>)>
+  "some.op"() : () -> !llvm.struct<(i32, float, !qux)>
+  // CHECK: !llvm.struct<"a", (ptr<struct<"a">>)>
+  "some.op"() : () -> !rec
+  llvm.return
+}


        


More information about the llvm-branch-commits mailing list