[Mlir-commits] [mlir] de3f7e2 - [mlir] Fix infinite recursion in alias initializer
Markus Böck
llvmlistbot at llvm.org
Sun Aug 27 08:32:24 PDT 2023
Author: Markus Böck
Date: 2023-08-27T17:31:35+02:00
New Revision: de3f7e2f0fb4363c17eec73ce79ca30e221ea844
URL: https://github.com/llvm/llvm-project/commit/de3f7e2f0fb4363c17eec73ce79ca30e221ea844
DIFF: https://github.com/llvm/llvm-project/commit/de3f7e2f0fb4363c17eec73ce79ca30e221ea844.diff
LOG: [mlir] Fix infinite recursion in alias initializer
The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable.
This is currently done naively, which leads to infinite recursion when mutable types or attributes contain a cycle.
This patch fixes this by returning early if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, it is guaranteed that once an alias is marked non-deferrable, all of its children are as well, so there is no need to walk them again.
As a side effect, marking aliases non-deferrable is now `O(n)` instead of `O(n^2)` (although this code is not performance sensitive).
Differential Revision: https://reviews.llvm.org/D158932
Added:
Modified:
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/recursive-type.mlir
mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index aeb6f0f3562635..333f4e537fcc74 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1056,6 +1056,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
auto it = std::next(aliases.begin(), aliasIndex);
+
+ // If already marked non-deferrable stop the recursion.
+ // All children should already be marked non-deferrable as well.
+ if (!it->second.canBeDeferred)
+ return;
+
it->second.canBeDeferred = false;
// Propagate the non-deferrable flag to any child aliases.
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index bc9b2cdbea6b67..121ba095573baa 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
+// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
@@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
+ // CHECK: () -> ![[$NAME]]
+ // CHECK: () -> ![[$NAME]]
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+
+ // CHECK: () -> ![[$NAME2]]
+ // CHECK: () -> ![[$NAME2]]
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 5da22ddb081292..950af85007475b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -218,6 +218,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
+ if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
+ os << recAliasType.getName();
+ return AliasResult::FinalAlias;
+ }
return AliasResult::NoAlias;
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f899d72219d058..2a8bdad8fb25d9 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -373,4 +373,22 @@ def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}
+// Mutable test type identified by its name. Its body is attached after
+// creation and may refer back to the type itself, which the alias tests use
+// to exercise cyclic types.
+def TestRecursiveAlias : Test_Type<"TestRecursiveAlias",
+                                   [NativeTypeTrait<"IsMutable">]> {
+  let mnemonic = "test_rec_alias";
+  let parameters = (ins "llvm::StringRef":$name);
+  let hasCustomAssemblyFormat = 1;
+
+  // Reuse the hand-written TestRecursiveTypeStorage (namespace `test`)
+  // instead of generating a storage class.
+  let genStorageClass = 0;
+  let storageNamespace = "test";
+  let storageClass = "TestRecursiveTypeStorage";
+
+  // Accessors for the mutable body; defined by hand in TestTypes.cpp.
+  let extraClassDeclaration = [{
+    Type getBody() const;
+
+    void setBody(Type type);
+  }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 0633752067a14f..20dc03a7652697 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}
+
+/// Returns the body currently stored for this alias type.
+Type TestRecursiveAliasType::getBody() const {
+  return getImpl()->body;
+}
+
+/// Attaches `type` as the body of this alias through the storage mutation
+/// hook. The result of the mutation is discarded.
+void TestRecursiveAliasType::setBody(Type type) {
+  static_cast<void>(Base::mutate(type));
+}
+
+/// Returns the name that identifies this alias type.
+StringRef TestRecursiveAliasType::getName() const {
+  return getImpl()->name;
+}
+
+/// Parses `<name, body>`, or just `<name>` when `name` refers to an alias
+/// whose body is already being parsed further up the current nesting (a
+/// cyclic self-reference). Returns a null Type on failure.
+Type TestRecursiveAliasType::parse(AsmParser &parser) {
+  // Aliases whose body is currently being parsed on this thread. A nested
+  // occurrence of one of these carries only the name, not the body.
+  thread_local static SetVector<Type> stack;
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
+
+  // If this type already has been parsed above in the stack, expect just the
+  // name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  auto bodyLoc = parser.getCurrentLocation();
+  // Always pop the entry, even when the body fails to parse. Otherwise `rec`
+  // would stay on the thread-local stack, and a later parse of the same name
+  // would be mistaken for a nested self-reference.
+  stack.insert(rec);
+  Type subtype;
+  ParseResult bodyResult = parser.parseType(subtype);
+  stack.pop_back();
+  if (failed(bodyResult) || !subtype || failed(parser.parseGreater()))
+    return Type();
+
+  // setBody discards the result of the underlying mutation, so verify that
+  // the requested body actually took effect instead of silently keeping a
+  // different, previously attached one.
+  rec.setBody(subtype);
+  if (rec.getBody() != subtype) {
+    parser.emitError(bodyLoc)
+        << "conflicting body for recursive alias '" << name << "'";
+    return Type();
+  }
+  return rec;
+}
+
+/// Prints `<name, body>` for the outermost occurrence of this alias and just
+/// `<name>` for occurrences nested inside its own body.
+void TestRecursiveAliasType::print(AsmPrinter &printer) const {
+  // Aliases whose body is currently being printed on this thread.
+  thread_local static SetVector<Type> printStack;
+
+  printer << "<" << getName();
+  const bool isNested = printStack.contains(*this);
+  if (!isNested) {
+    printStack.insert(*this);
+    printer << ", " << getBody();
+    printStack.pop_back();
+  }
+  printer << ">";
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index c7d169d020d56f..0ce86dd70ab904 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
#include "TestTypeInterfaces.h.inc"
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.h.inc"
-
namespace test {
/// Storage for simple named recursive types, where the type is identified by
@@ -150,4 +147,7 @@ class TestRecursiveType
} // namespace test
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.h.inc"
+
#endif // MLIR_TESTTYPES_H
More information about the Mlir-commits
mailing list