[Mlir-commits] [mlir] [mlir] Support better printing for mutually recursive types (PR #112160)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 13 22:38:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Shoaib Meenai (smeenai)

<details>
<summary>Changes</summary>

For mutually recursive types, the current printing scheme forces the
earlier type alias to include a full definition of the later type.
Many recursive types (e.g. structs in ClangIR) have a notion of an
incomplete type definition. By exposing a simple hook in the AsmPrinter
that reports whether a type's alias will be printed in the future, we
enable dialects to print an incomplete type definition (which they know
will be completed later) instead of the full definition when printing
mutually recursive types. This makes the output much easier to read.


---
Full diff: https://github.com/llvm/llvm-project/pull/112160.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/OpImplementation.h (+5) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+37-3) 
- (modified) mlir/test/IR/recursive-type.mlir (+13-4) 
- (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+5-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..2b79727d8c9325 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -188,6 +188,11 @@ class AsmPrinter {
   /// be printed.
   virtual LogicalResult printAlias(Type type);
 
+  /// Check if the given type has an alias that will be printed in the future.
+  /// Returns false if the type has an alias that's currently being printed or
+  /// has already been printed. This can aid printing mutually recursive types.
+  virtual bool hasFutureAlias(Type type) const;
+
   /// Print the given string as a keyword, or a quoted and escaped string if it
   /// has any special or non-printable characters in it.
   virtual void printKeywordOrString(StringRef keyword);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a728425f2ec6ba..2c5e0f5b92a4e4 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -430,6 +430,11 @@ class AsmPrinter::Impl {
   /// be printed.
   LogicalResult printAlias(Type type);
 
+  /// Check if the given type has an alias that will be printed in the future.
+  /// Returns false if the type has an alias that's currently being printed or
+  /// has already been printed. This can aid printing mutually recursive types.
+  bool hasFutureAlias(Type type) const;
+
   /// Print the given location to the stream. If `allowAlias` is true, this
   /// allows for the internal location to use an attribute alias.
   void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -547,8 +552,13 @@ class SymbolAlias {
   bool isDeferrable : 1;
 
 public:
+  /// Used to distinguish aliases that are currently being or have previously
+  /// been printed from those that will be printed in the future, which can aid
+  /// printing mutually recursive types.
+  bool hasStartedPrinting = false;
+
   /// Used to avoid printing incomplete aliases for recursive types.
-  bool isPrinted = false;
+  bool hasFinishedPrinting = false;
 };
 
 /// This class represents a utility that initializes the set of attribute and
@@ -774,6 +784,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
     initializer.visit(type);
     return success();
   }
+  bool hasFutureAlias(Type) const override { return false; }
 
   /// Consider the given location to be printed for an alias.
   void printOptionalLocationSpecifier(Location loc) override {
@@ -948,6 +959,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
     printType(type);
     return success();
   }
+  bool hasFutureAlias(Type) const override { return false; }
 
   /// Record the alias result of a child element.
   void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
@@ -1182,6 +1194,11 @@ class AliasState {
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Type ty, raw_ostream &os) const;
 
+  /// Check if the given type has an alias that will be printed in the future.
+  /// Returns false if the type has an alias that's currently being printed or
+  /// has already been printed. This can aid printing mutually recursive types.
+  bool hasFutureAlias(Type ty) const;
+
   /// Print all of the referenced aliases that can not be resolved in a deferred
   /// manner.
   void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1226,13 +1243,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
   const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
   if (it == attrTypeToAlias.end())
     return failure();
-  if (!it->second.isPrinted)
+  if (!it->second.hasFinishedPrinting)
     return failure();
 
   it->second.print(os);
   return success();
 }
 
+bool AliasState::hasFutureAlias(Type ty) const {
+  const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
+  if (it == attrTypeToAlias.end())
+    return false;
+  return !it->second.hasStartedPrinting;
+}
+
 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                               bool isDeferred) {
   auto filterFn = [=](const auto &aliasIt) {
@@ -1245,8 +1269,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
 
     if (alias.isTypeAlias()) {
       Type type = Type::getFromOpaquePointer(opaqueSymbol);
+      alias.hasStartedPrinting = true;
       p.printTypeImpl(type);
-      alias.isPrinted = true;
+      alias.hasFinishedPrinting = true;
     } else {
       // TODO: Support nested aliases in mutable attributes.
       Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
@@ -2234,6 +2259,10 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
   return state.getAliasState().getAlias(type, os);
 }
 
+bool AsmPrinter::Impl::hasFutureAlias(Type type) const {
+  return state.getAliasState().hasFutureAlias(type);
+}
+
 void AsmPrinter::Impl::printAttribute(Attribute attr,
                                       AttrTypeElision typeElision) {
   if (!attr) {
@@ -2832,6 +2861,11 @@ LogicalResult AsmPrinter::printAlias(Type type) {
   return impl->printAlias(type);
 }
 
+bool AsmPrinter::hasFutureAlias(Type type) const {
+  assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden");
+  return impl->hasFutureAlias(type);
+}
+
 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
   assert(impl &&
          "expected AsmPrinter::printAttributeWithoutType to be overriden");
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d1..c5d0cd09b220f1 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -2,10 +2,12 @@
 
 // CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
 // CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
-// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
+// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3>>
+// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
 // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, ![[$NAME5]]>
+// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias<name6, ![[$NAME7]]>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, ![[$NAME4]]>
 
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
@@ -28,13 +30,20 @@ func.func @roundtrip() {
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
 
-  // Mutual recursion.
+  // Mutual recursion with types fully spelled out.
   // CHECK: () -> ![[$NAME3]]
   // CHECK: () -> ![[$NAME4]]
   // CHECK: () -> ![[$NAME5]]
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3>>>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4>>>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
+
+  // Mutual recursion with incomplete types.
+  // CHECK: () -> ![[$NAME6]]
+  // CHECK: () -> ![[$NAME7]]
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name6, !test.test_rec_alias<name7>>
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
+
   return
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 1593b6d7d7534b..94576544ba3069 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
     return rec;
   }
 
+  // Allow incomplete definitions that can be completed later.
+  if (succeeded(parser.parseGreater()))
+    return rec;
+
   // Otherwise, parse the body and update the type.
   if (failed(parser.parseComma()))
     return Type();
@@ -525,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
       printer.tryStartCyclicPrint(*this);
 
   printer << "<" << getName();
-  if (succeeded(cyclicPrint)) {
+  if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) {
     printer << ", ";
     printer << getBody();
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/112160


More information about the Mlir-commits mailing list