[Mlir-commits] [mlir] de3f7e2 - [mlir] Fix infinite recursion in alias initializer

Markus Böck llvmlistbot at llvm.org
Sun Aug 27 08:32:24 PDT 2023


Author: Markus Böck
Date: 2023-08-27T17:31:35+02:00
New Revision: de3f7e2f0fb4363c17eec73ce79ca30e221ea844

URL: https://github.com/llvm/llvm-project/commit/de3f7e2f0fb4363c17eec73ce79ca30e221ea844
DIFF: https://github.com/llvm/llvm-project/commit/de3f7e2f0fb4363c17eec73ce79ca30e221ea844.diff

LOG: [mlir] Fix infinite recursion in alias initializer

The alias initializer keeps a list of child alias indices for each alias. When an alias is marked as non-deferrable, all of its children are marked non-deferrable as well.

This is currently done naively, which leads to infinite recursion when mutable types or attributes contain a cycle.

This patch fixes this by adding an early return if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, an alias that is already marked has either had all of its children marked already or, in a cycle, is having them marked by a call further up the stack. Either way, there is no need to walk its children again.
As a side effect, this also makes the non-deferrable marking `O(n)` instead of `O(n^2)` (although this code is not performance-sensitive).

Differential Revision: https://reviews.llvm.org/D158932

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/IR/recursive-type.mlir
    mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/lib/Dialect/Test/TestTypes.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index aeb6f0f3562635..333f4e537fcc74 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1056,6 +1056,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
 
 void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
   auto it = std::next(aliases.begin(), aliasIndex);
+
+  // Already marked: its children were marked by an earlier call or are being
+  // marked by a caller up the stack (cyclic types/attrs). Stop recursing.
+  if (!it->second.canBeDeferred)
+    return;
+
   it->second.canBeDeferred = false;
 
   // Propagate the non-deferrable flag to any child aliases.

diff  --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index bc9b2cdbea6b67..121ba095573baa 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt %s -test-recursive-types | FileCheck %s
 
 // CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
+// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
 
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
@@ -12,6 +14,16 @@ func.func @roundtrip() {
   // into inifinite recursion.
   // CHECK: !testrec
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
+  // CHECK: () -> ![[$NAME]]
+  // CHECK: () -> ![[$NAME]]
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
+
+  // CHECK: () -> ![[$NAME2]]
+  // CHECK: () -> ![[$NAME2]]
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 5da22ddb081292..950af85007475b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -218,6 +218,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         return AliasResult::FinalAlias;
       }
     }
+    if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
+      os << recAliasType.getName();
+      return AliasResult::FinalAlias;
+    }
     return AliasResult::NoAlias;
   }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f899d72219d058..2a8bdad8fb25d9 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -373,4 +373,22 @@ def TestI32 : Test_Type<"TestI32"> {
   let mnemonic = "i32";
 }
 
+def TestRecursiveAlias
+    : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
+  let mnemonic = "test_rec_alias";
+  let storageClass = "TestRecursiveTypeStorage";
+  let storageNamespace = "test";
+  let genStorageClass = 0; // Reuse the existing storage class named above.
+
+  let parameters = (ins "llvm::StringRef":$name); // Body isn't a parameter.
+  // parse() and print() are hand-written in TestTypes.cpp.
+  let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    Type getBody() const;
+    // Sets the mutable body after construction (called from parse()).
+    void setBody(Type type);
+  }];
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 0633752067a14f..20dc03a7652697 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -482,3 +482,65 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
   SetVector<Type> stack;
   printTestType(type, printer, stack);
 }
+
+// Returns the body stored next to the name in the type's storage.
+Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
+
+// Sets the body via Base::mutate(). NOTE(review): the result of mutate() is
+// discarded, so a rejected update is silently ignored; confirm intended.
+void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
+
+// Returns the name the type was created with.
+StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
+
+/// Parses `<name, body>`, or just `<name>` for a nested reference to a type
+/// whose body is currently being parsed.
+Type TestRecursiveAliasType::parse(AsmParser &parser) {
+  // Types whose body is currently being parsed on this thread.
+  thread_local static SetVector<Type> stack;
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
+
+  // If this type already has been parsed above in the stack, expect just the
+  // name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  stack.insert(rec);
+  Type subtype;
+  bool bodyFailed = failed(parser.parseType(subtype));
+  // Pop even on failure: a stale entry would make every later parse of this
+  // name on this thread expect the short `<name>` form.
+  stack.pop_back();
+  if (bodyFailed || !subtype || failed(parser.parseGreater()))
+    return Type();
+
+  rec.setBody(subtype);
+
+  return rec;
+}
+
+/// Prints `<name, body>`, or just `<name>` for a nested occurrence of a type
+/// whose body is currently being printed, which breaks the cycle.
+void TestRecursiveAliasType::print(AsmPrinter &printer) const {
+  // Types whose body is currently being printed on this thread.
+  thread_local static SetVector<Type> stack;
+
+  printer << "<" << getName();
+  if (!stack.contains(*this)) {
+    printer << ", ";
+    stack.insert(*this);
+    printer << getBody();
+    stack.pop_back();
+  }
+  printer << ">";
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index c7d169d020d56f..0ce86dd70ab904 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
 
 #include "TestTypeInterfaces.h.inc"
 
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.h.inc"
-
 namespace test {
 
 /// Storage for simple named recursive types, where the type is identified by
@@ -150,4 +147,7 @@ class TestRecursiveType
 
 } // namespace test
 
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.h.inc"
+
 #endif // MLIR_TESTTYPES_H


        


More information about the Mlir-commits mailing list