[Mlir-commits] [mlir] [mlir] Support better printing for mutually recursive types (PR #112160)
Shoaib Meenai
llvmlistbot at llvm.org
Mon Oct 14 16:48:12 PDT 2024
https://github.com/smeenai updated https://github.com/llvm/llvm-project/pull/112160
>From 9aeefa8eef91b9966bd5b0226bd70e59c783e0fc Mon Sep 17 00:00:00 2001
From: Shoaib Meenai <smeenai at fb.com>
Date: Fri, 11 Oct 2024 22:02:29 -0700
Subject: [PATCH 1/2] [mlir] Support better printing for mutually recursive
types
For mutually recursive types, the current way types are printed forces
the earlier type alias to include a full definition of the later type.
Many recursive types (e.g. structs in ClangIR) have a notion of an
incomplete type definition. By exposing a simple hook in the AsmPrinter
that determines whether a type will be printed in the future, we can let
dialects print incomplete definitions for mutually recursive types
(knowing that the complete definition will be emitted later), which makes
the output much easier to read.
---
mlir/include/mlir/IR/OpImplementation.h | 5 +++
mlir/lib/IR/AsmPrinter.cpp | 40 ++++++++++++++++++++++--
mlir/test/IR/recursive-type.mlir | 17 +++++++---
mlir/test/lib/Dialect/Test/TestTypes.cpp | 6 +++-
4 files changed, 60 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e2472eea8a3714..2b79727d8c9325 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -188,6 +188,11 @@ class AsmPrinter {
/// be printed.
virtual LogicalResult printAlias(Type type);
+ /// Check if the given type has an alias that will be printed in the future.
+ /// Returns false if the type has an alias that's currently being printed or
+ /// has already been printed. This can aid printing mutually recursive types.
+ virtual bool hasFutureAlias(Type type) const;
+
/// Print the given string as a keyword, or a quoted and escaped string if it
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a728425f2ec6ba..2c5e0f5b92a4e4 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -430,6 +430,11 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Type type);
+ /// Check if the given type has an alias that will be printed in the future.
+ /// Returns false if the type has an alias that's currently being printed or
+ /// has already been printed. This can aid printing mutually recursive types.
+ bool hasFutureAlias(Type type) const;
+
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -547,8 +552,13 @@ class SymbolAlias {
bool isDeferrable : 1;
public:
+ /// Used to distinguish aliases that are currently being or have previously
+ /// been printed from those that will be printed in the future, which can aid
+ /// printing mutually recursive types.
+ bool hasStartedPrinting = false;
+
/// Used to avoid printing incomplete aliases for recursive types.
- bool isPrinted = false;
+ bool hasFinishedPrinting = false;
};
/// This class represents a utility that initializes the set of attribute and
@@ -774,6 +784,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
initializer.visit(type);
return success();
}
+ bool hasFutureAlias(Type) const override { return false; }
/// Consider the given location to be printed for an alias.
void printOptionalLocationSpecifier(Location loc) override {
@@ -948,6 +959,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
printType(type);
return success();
}
+ bool hasFutureAlias(Type) const override { return false; }
/// Record the alias result of a child element.
void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
@@ -1182,6 +1194,11 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;
+ /// Check if the given type has an alias that will be printed in the future.
+ /// Returns false if the type has an alias that's currently being printed or
+ /// has already been printed. This can aid printing mutually recursive types.
+ bool hasFutureAlias(Type ty) const;
+
/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1226,13 +1243,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
if (it == attrTypeToAlias.end())
return failure();
- if (!it->second.isPrinted)
+ if (!it->second.hasFinishedPrinting)
return failure();
it->second.print(os);
return success();
}
+bool AliasState::hasFutureAlias(Type ty) const {
+ const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
+ if (it == attrTypeToAlias.end())
+ return false;
+ return !it->second.hasStartedPrinting;
+}
+
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
@@ -1245,8 +1269,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
if (alias.isTypeAlias()) {
Type type = Type::getFromOpaquePointer(opaqueSymbol);
+ alias.hasStartedPrinting = true;
p.printTypeImpl(type);
- alias.isPrinted = true;
+ alias.hasFinishedPrinting = true;
} else {
// TODO: Support nested aliases in mutable attributes.
Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
@@ -2234,6 +2259,10 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return state.getAliasState().getAlias(type, os);
}
+bool AsmPrinter::Impl::hasFutureAlias(Type type) const {
+ return state.getAliasState().hasFutureAlias(type);
+}
+
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@@ -2832,6 +2861,11 @@ LogicalResult AsmPrinter::printAlias(Type type) {
return impl->printAlias(type);
}
+bool AsmPrinter::hasFutureAlias(Type type) const {
+ assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden");
+ return impl->hasFutureAlias(type);
+}
+
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d1..c5d0cd09b220f1 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -2,10 +2,12 @@
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
-// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
+// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3>>
+// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, ![[$NAME5]]>
+// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias<name6, ![[$NAME7]]>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, ![[$NAME4]]>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
@@ -28,13 +30,20 @@ func.func @roundtrip() {
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
- // Mutual recursion.
+ // Mutual recursion with types fully spelled out.
// CHECK: () -> ![[$NAME3]]
// CHECK: () -> ![[$NAME4]]
// CHECK: () -> ![[$NAME5]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
+
+ // Mutual recursion with incomplete types.
+ // CHECK: () -> ![[$NAME6]]
+ // CHECK: () -> ![[$NAME7]]
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name6, !test.test_rec_alias<name7>>
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
+
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 1593b6d7d7534b..94576544ba3069 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
return rec;
}
+ // Allow incomplete definitions that can be completed later.
+ if (succeeded(parser.parseGreater()))
+ return rec;
+
// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
@@ -525,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
printer.tryStartCyclicPrint(*this);
printer << "<" << getName();
- if (succeeded(cyclicPrint)) {
+ if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) {
printer << ", ";
printer << getBody();
}
>From 75d62775d20ac038edb8cc7a7ffb5e54fdc73026 Mon Sep 17 00:00:00 2001
From: Shoaib Meenai <smeenai at fb.com>
Date: Mon, 14 Oct 2024 16:38:23 -0700
Subject: [PATCH 2/2] Incorporate into tryStartCyclicPrint
---
mlir/include/mlir/IR/OpImplementation.h | 20 ++++++++----
mlir/lib/IR/AsmPrinter.cpp | 41 +++++++++++-------------
mlir/test/lib/Dialect/Test/TestTypes.cpp | 2 +-
3 files changed, 33 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 2b79727d8c9325..b62b6706df9671 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -188,11 +188,6 @@ class AsmPrinter {
/// be printed.
virtual LogicalResult printAlias(Type type);
- /// Check if the given type has an alias that will be printed in the future.
- /// Returns false if the type has an alias that's currently being printed or
- /// has already been printed. This can aid printing mutually recursive types.
- virtual bool hasFutureAlias(Type type) const;
-
/// Print the given string as a keyword, or a quoted and escaped string if it
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);
@@ -270,8 +265,11 @@ class AsmPrinter {
/// Attempts to start a cyclic printing region for `attrOrType`.
/// A cyclic printing region starts with this call and ends with the
/// destruction of the returned `CyclicPrintReset`. During this time,
- /// calling `tryStartCyclicPrint` with the same attribute in any printer
- /// will lead to returning failure.
+ /// calling `tryStartCyclicPrint` with the same attribute or type in any
+ /// printer will lead to returning failure. Additionally, if the printer
+ /// knows a complete definition of the attribute or type will be emitted in
+ /// the future, it'll also return failure to permit abbreviated definitions
+ /// to be used wherever possible.
///
/// This makes it possible to break infinite recursions when trying to print
/// cyclic attributes or types by printing only immutable parameters if nested
@@ -283,6 +281,8 @@ class AsmPrinter {
AttrOrTypeT> ||
std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
"Only mutable attributes or types can be cyclic");
+ if (hasFutureAlias(attrOrType.getAsOpaquePointer()))
+ return failure();
if (failed(pushCyclicPrinting(attrOrType.getAsOpaquePointer())))
return failure();
return CyclicPrintReset(this);
@@ -304,6 +304,12 @@ class AsmPrinter {
/// in reverse order of all successful `pushCyclicPrinting`.
virtual void popCyclicPrinting();
+ /// Check if the given attribute or type (in the form of a type erased
+ /// pointer) will be printed as an alias in the future. Returns false if the
+ /// type has an alias that's currently being printed or has already been
+ /// printed. This enables cyclic print checking for mutual recursion.
+ virtual bool hasFutureAlias(const void *opaquePointer) const;
+
private:
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2c5e0f5b92a4e4..a62443932ddd12 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -430,11 +430,6 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Type type);
- /// Check if the given type has an alias that will be printed in the future.
- /// Returns false if the type has an alias that's currently being printed or
- /// has already been printed. This can aid printing mutually recursive types.
- bool hasFutureAlias(Type type) const;
-
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -454,6 +449,8 @@ class AsmPrinter::Impl {
void popCyclicPrinting();
+ bool hasFutureAlias(const void *opaquePointer) const;
+
void printDimensionList(ArrayRef<int64_t> shape);
protected:
@@ -784,7 +781,6 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
initializer.visit(type);
return success();
}
- bool hasFutureAlias(Type) const override { return false; }
/// Consider the given location to be printed for an alias.
void printOptionalLocationSpecifier(Location loc) override {
@@ -959,7 +955,6 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
printType(type);
return success();
}
- bool hasFutureAlias(Type) const override { return false; }
/// Record the alias result of a child element.
void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
@@ -986,6 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }
+ bool hasFutureAlias(const void *) const override { return false; }
+
/// Stack of potentially cyclic mutable attributes or type currently being
/// printed.
SetVector<const void *> cyclicPrintingStack;
@@ -1194,10 +1191,11 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;
- /// Check if the given type has an alias that will be printed in the future.
- /// Returns false if the type has an alias that's currently being printed or
- /// has already been printed. This can aid printing mutually recursive types.
- bool hasFutureAlias(Type ty) const;
+ /// Check if the given attribute or type (in the form of a type erased
+ /// pointer) will be printed as an alias in the future. Returns false if the
+ /// type has an alias that's currently being printed or has already been
+ /// printed. This enables cyclic print checking for mutual recursion.
+ bool hasFutureAlias(const void *opaquePointer) const;
/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
@@ -1250,8 +1248,10 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
return success();
}
-bool AliasState::hasFutureAlias(Type ty) const {
- const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
+bool AliasState::hasFutureAlias(const void *opaquePointer) const {
+ const auto *it = attrTypeToAlias.find(opaquePointer);
if (it == attrTypeToAlias.end())
return false;
- return !it->second.hasStartedPrinting;
+ // Attribute aliases never set `hasStartedPrinting` (see printAliases), so
+ // only type aliases may be reported as future aliases.
+ return it->second.isTypeAlias() && !it->second.hasStartedPrinting;
@@ -2259,10 +2257,6 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return state.getAliasState().getAlias(type, os);
}
-bool AsmPrinter::Impl::hasFutureAlias(Type type) const {
- return state.getAliasState().hasFutureAlias(type);
-}
-
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@@ -2820,6 +2814,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
+bool AsmPrinter::Impl::hasFutureAlias(const void *opaquePointer) const {
+ return state.getAliasState().hasFutureAlias(opaquePointer);
+}
+
void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
detail::printDimensionList(os, shape);
}
@@ -2861,11 +2859,6 @@ LogicalResult AsmPrinter::printAlias(Type type) {
return impl->printAlias(type);
}
-bool AsmPrinter::hasFutureAlias(Type type) const {
- assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden");
- return impl->hasFutureAlias(type);
-}
-
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
@@ -2904,6 +2897,11 @@ LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
+bool AsmPrinter::hasFutureAlias(const void *opaquePointer) const {
+ assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden");
+ return impl->hasFutureAlias(opaquePointer);
+}
+
//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 94576544ba3069..48f519faba40f1 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -529,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
printer.tryStartCyclicPrint(*this);
printer << "<" << getName();
- if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) {
+ if (succeeded(cyclicPrint)) {
printer << ", ";
printer << getBody();
}
More information about the Mlir-commits
mailing list