[llvm-branch-commits] [mlir] b66219d - [mlir] Fix infinite recursion in alias initializer
Tobias Hieta via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 30 23:56:16 PDT 2023
Author: Markus Böck
Date: 2023-08-31T08:54:24+02:00
New Revision: b66219d735006fafeeb2b2a1194821daee2f7245
URL: https://github.com/llvm/llvm-project/commit/b66219d735006fafeeb2b2a1194821daee2f7245
DIFF: https://github.com/llvm/llvm-project/commit/b66219d735006fafeeb2b2a1194821daee2f7245.diff
LOG: [mlir] Fix infinite recursion in alias initializer
The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable.
This is currently done naively, which leads to infinite recursion when mutable types or attributes contain a cycle.
This patch fixes it by returning early if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, an alias that is already marked is guaranteed to have all of its children marked as well, so there is no need to walk them again.
Incidentally, this also makes non-deferrable marking `O(n)` instead of `O(n^2)` (although this code is not performance-sensitive).
Differential Revision: https://reviews.llvm.org/D158932
Added:
Modified:
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/recursive-type.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 325f986f976944..af415326708904 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1043,6 +1043,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
auto it = std::next(aliases.begin(), aliasIndex);
+
+ // If already marked non-deferrable stop the recursion.
+ // All children should already be marked non-deferrable as well.
+ if (!it->second.canBeDeferred)
+ return;
+
it->second.canBeDeferred = false;
// Propagate the non-deferrable flag to any child aliases.
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index bc9b2cdbea6b67..121ba095573baa 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
+// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
@@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
+ // CHECK: () -> ![[$NAME]]
+ // CHECK: () -> ![[$NAME]]
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+
+ // CHECK: () -> ![[$NAME2]]
+ // CHECK: () -> ![[$NAME2]]
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 072f6ff4b84d33..debe733f59be40 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -312,6 +312,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
+ if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
+ os << recAliasType.getName();
+ return AliasResult::FinalAlias;
+ }
return AliasResult::NoAlias;
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 15dbd74aec118f..2a8bdad8fb25d9 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -369,4 +369,26 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}
+// Parameterless test-dialect type with mnemonic `i32`, spelled `!test.i32`.
+// NOTE(review): the recursive-type test added in this patch uses the builtin
+// `i32`, not this type -- confirm this definition is still needed.
+def TestI32 : Test_Type<"TestI32"> {
+ let mnemonic = "i32";
+}
+
+// Named test type whose body may refer back to itself (e.g.
+// `!test.test_rec_alias<name, !test.test_rec_alias<name>>`). It is used to
+// exercise alias generation for cyclic types. The type is mutable so its body
+// can be attached after the named type has been uniqued.
+def TestRecursiveAlias
+ : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
+ let mnemonic = "test_rec_alias";
+ // Reuse the hand-written test::TestRecursiveTypeStorage instead of
+ // generating a storage class.
+ let storageClass = "TestRecursiveTypeStorage";
+ let storageNamespace = "test";
+ let genStorageClass = 0;
+
+ // The name is the only uniquing parameter; the body is mutable state.
+ let parameters = (ins "llvm::StringRef":$name);
+
+ // parse()/print() are implemented by hand in TestTypes.cpp.
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ Type getBody() const;
+
+ void setBody(Type type);
+ }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 0633752067a14f..20dc03a7652697 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}
+
+/// Returns the body held in this type's storage.
+/// NOTE(review): presumably null until setBody() has been called -- confirm
+/// against TestRecursiveTypeStorage.
+Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
+
+/// Sets the body through Base::mutate. The result of mutate is explicitly
+/// discarded, so a mutation the storage rejects is not reported to callers.
+void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
+
+/// Returns the name that uniquely identifies this type.
+StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
+
+/// Parses `<name, body>`. If a type with the same name is already having its
+/// body parsed on this thread, parses just the back-reference `<name>`.
+/// Returns a null Type on any parse failure.
+Type TestRecursiveAliasType::parse(AsmParser &parser) {
+  // Types whose bodies are currently being parsed on this thread, innermost
+  // last. Every insert below must be paired with a pop, even on error.
+  thread_local static SetVector<Type> stack;
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
+
+  // If this type is already being parsed further up the stack, this is a
+  // back-reference: expect just the name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  stack.insert(rec);
+  Type subtype;
+  bool bodyFailed = failed(parser.parseType(subtype));
+  // Pop before checking for errors. Leaving `rec` on the thread-local stack
+  // after a failed body parse would make every later parse of this name on
+  // the same thread accept a body-less `<name>`.
+  stack.pop_back();
+  if (bodyFailed || !subtype || failed(parser.parseGreater()))
+    return Type();
+
+  rec.setBody(subtype);
+  return rec;
+}
+
+/// Prints `<name, body>`. When this type is reached again while its own body
+/// is being printed on this thread, prints only `<name>`, which breaks the
+/// cycle.
+void TestRecursiveAliasType::print(AsmPrinter &printer) const {
+  // Types whose bodies are currently being printed on this thread.
+  thread_local static SetVector<Type> stack;
+
+  printer << "<" << getName();
+  if (stack.contains(*this)) {
+    // Back-reference: the body is already being printed further up.
+    printer << ">";
+    return;
+  }
+
+  stack.insert(*this);
+  printer << ", " << getBody();
+  stack.pop_back();
+  printer << ">";
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index c7d169d020d56f..0ce86dd70ab904 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
#include "TestTypeInterfaces.h.inc"
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.h.inc"
-
namespace test {
/// Storage for simple named recursive types, where the type is identified by
@@ -150,4 +147,7 @@ class TestRecursiveType
} // namespace test
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.h.inc"
+
#endif // MLIR_TESTTYPES_H
More information about the llvm-branch-commits
mailing list