[Mlir-commits] [mlir] d410286 - [mlir] Prevent SubElementInterface from going into infinite recursion

Min-Yih Hsu llvmlistbot at llvm.org
Wed Jun 29 13:58:38 PDT 2022


Author: Min-Yih Hsu
Date: 2022-06-29T13:58:02-07:00
New Revision: d41028610b5372669adcb9b7091fae5250f0a4a8

URL: https://github.com/llvm/llvm-project/commit/d41028610b5372669adcb9b7091fae5250f0a4a8
DIFF: https://github.com/llvm/llvm-project/commit/d41028610b5372669adcb9b7091fae5250f0a4a8.diff

LOG: [mlir] Prevent SubElementInterface from going into infinite recursion

Only mutable types and attributes can send
SubElementInterface::walkSubElements into infinite recursion, and there
are only a few of them. We therefore introduce two new traits,
TypeTrait::IsMutable and AttributeTrait::IsMutable, which mark a Type or
an Attribute, respectively, as mutable. A type or attribute must carry
the corresponding trait if its ImplType defines a `mutate` function
(this is enforced by a static_assert in StorageUserBase::mutate).

Then, inside SubElementInterface, we record each mutable type and
attribute we visit in a set, and skip any that have already been visited.

Differential Revision: https://reviews.llvm.org/D127537

Added: 
    mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
    mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
    mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/test/IR/recursive-type.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/unittests/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index c3a244f021c79..50537f6c9abed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -264,7 +264,8 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
 /// structs, but does not in uniquing of identified structs.
 class LLVMStructType
     : public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
-                            DataLayoutTypeInterface::Trait> {
+                            DataLayoutTypeInterface::Trait,
+                            TypeTrait::IsMutable> {
 public:
   /// Inherit base constructors.
   using Base::Base;

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 0d4471ee51c64..40a4acff751b1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -275,8 +275,9 @@ class SampledImageType
 /// In the above, expressing recursive struct types is accomplished by giving a
 /// recursive struct a unique identified and using that identifier in the struct
 /// definition for recursive references.
-class StructType : public Type::TypeBase<StructType, CompositeType,
-                                         detail::StructTypeStorage> {
+class StructType
+    : public Type::TypeBase<StructType, CompositeType,
+                            detail::StructTypeStorage, TypeTrait::IsMutable> {
 public:
   using Base::Base;
 

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index da5b78edf92be..2195057f55ab9 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -231,6 +231,18 @@ class AttributeInterface
   friend InterfaceBase;
 };
 
+//===----------------------------------------------------------------------===//
+// Core AttributeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if an attribute is mutable or not. It
+/// should be attached to an attribute whose ImplType defines a `mutate`
+/// function with a proper signature.
+namespace AttributeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace AttributeTrait
+
 } // namespace mlir.
 
 namespace llvm {

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 6d854f66f6ff7..7aaa6881ebd8f 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -53,6 +53,16 @@ class StorageUserTraitBase {
   }
 };
 
+namespace StorageUserTrait {
+/// This trait is used to determine if a storage user, like Type, is mutable
+/// or not. A storage user is mutable if ImplType of the derived class defines
+/// a `mutate` function with a proper signature. Note that this trait is not
+/// supposed to be used publicly. Users should use alias names like
+/// `TypeTrait::IsMutable` instead.
+template <typename ConcreteType>
+struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
+} // namespace StorageUserTrait
+
 //===----------------------------------------------------------------------===//
 // StorageUserBase
 //===----------------------------------------------------------------------===//
@@ -173,6 +183,10 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// Mutate the current storage instance. This will not change the unique key.
   /// The arguments are forwarded to 'ConcreteT::mutate'.
   template <typename... Args> LogicalResult mutate(Args &&...args) {
+    static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
+                                  ConcreteT>::value,
+                  "The `mutate` function expects mutable trait "
+                  "(e.g. TypeTrait::IsMutable) to be attached on parent.");
     return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
                                                 std::forward<Args>(args)...);
   }

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index f2a69660ae9f7..5af73894c9c52 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -222,6 +222,18 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
   friend InterfaceBase;
 };
 
+//===----------------------------------------------------------------------===//
+// Core TypeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if a type is mutable or not. It should be
+/// attached to a type whose ImplType defines a `mutate` function with a
+/// proper signature.
+namespace TypeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace TypeTrait
+
 //===----------------------------------------------------------------------===//
 // Type Utils
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index 0a4875cf7d8aa..4059b99b5db24 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/IR/SubElementInterfaces.h"
 
+#include "llvm/ADT/DenseSet.h"
+
 using namespace mlir;
 
 template <typename InterfaceT>
 static void walkSubElementsImpl(InterfaceT interface,
                                 function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) {
+                                function_ref<void(Type)> walkTypesFn,
+                                DenseSet<Attribute> &visitedAttrs,
+                                DenseSet<Type> &visitedTypes) {
   interface.walkImmediateSubElements(
       [&](Attribute attr) {
         // Guard against potentially null inputs. This removes the need for the
@@ -21,9 +25,17 @@ static void walkSubElementsImpl(InterfaceT interface,
         if (!attr)
           return;
 
+        // Avoid infinite recursion when visiting sub attributes later, if this
+        // is a mutable attribute.
+        if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
+          if (!visitedAttrs.insert(attr).second)
+            return;
+        }
+
         // Walk any sub elements first.
         if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this attribute.
         walkAttrsFn(attr);
@@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
         if (!type)
           return;
 
+        // Avoid infinite recursion when visiting sub types later, if this
+        // is a mutable type.
+        if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
+          if (!visitedTypes.insert(type).second)
+            return;
+        }
+
         // Walk any sub elements first.
         if (auto interface = type.dyn_cast<SubElementTypeInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this type.
         walkTypesFn(type);
@@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  DenseSet<Attribute> visitedAttrs;
+  DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 void SubElementTypeInterface::walkSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  DenseSet<Attribute> visitedAttrs;
+  DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index e66d7fd5a5ab4..bc9b2cdbea6b6 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,11 +1,17 @@
 // RUN: mlir-opt %s -test-recursive-types | FileCheck %s
 
+// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
 // CHECK-LABEL: @roundtrip
 func.func @roundtrip() {
   // CHECK: !test.test_rec<a, test_rec<b, test_type>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
   // CHECK: !test.test_rec<c, test_rec<c>>
   "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
+  // Make sure walkSubElements, which is used to generate aliases, doesn't go
+  // into infinite recursion.
+  // CHECK: !testrec
+  "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3585d43b3128b..653f0e10ae319 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         return AliasResult::FinalAlias;
       }
     }
+    if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+      if (recType.getName() == "type_to_alias") {
+        // We only make alias for a specific recursive type.
+        os << "testrec";
+        return AliasResult::FinalAlias;
+      }
+    }
     return AliasResult::NoAlias;
   }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index bd6421d04735e..8772efd8590fc 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
@@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
 /// from type creation.
 class TestRecursiveType
     : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
-                                    TestRecursiveTypeStorage> {
+                                    TestRecursiveTypeStorage,
+                                    ::mlir::SubElementTypeInterface::Trait,
+                                    ::mlir::TypeTrait::IsMutable> {
 public:
   using Base::Base;
 
@@ -141,10 +144,16 @@ class TestRecursiveType
 
   /// Body getter and setter.
   ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
-  ::mlir::Type getBody() { return getImpl()->body; }
+  ::mlir::Type getBody() const { return getImpl()->body; }
 
   /// Name/key getter.
   ::llvm::StringRef getName() { return getImpl()->name; }
+
+  void walkImmediateSubElements(
+      ::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
+      ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
+    walkTypesFn(getBody());
+  }
 };
 
 } // namespace test

diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index ec89b14837de9..f8e5e46e5ac3b 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
   MLIRDialect)
 
 add_subdirectory(Affine)
+add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
-
 add_subdirectory(Quant)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)

diff  --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000000000..92af1856c68e0
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRLLVMIRTests
+  LLVMTypeTest.cpp
+)
+target_link_libraries(MLIRLLVMIRTests
+  PRIVATE
+  MLIRLLVMDialect
+  )

diff  --git a/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
new file mode 100644
index 0000000000000..1badc44ce2132
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
@@ -0,0 +1,27 @@
+//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test fixture for LLVM dialect tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+class LLVMIRTest : public ::testing::Test {
+protected:
+  LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
+
+  mlir::MLIRContext context;
+};
+
+#endif

diff  --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
new file mode 100644
index 0000000000000..9c0ea4f14d766
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -0,0 +1,20 @@
+//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMTestBase.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/SubElementInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+TEST_F(LLVMIRTest, IsStructTypeMutable) {
+  auto structTy = LLVMStructType::getIdentified(&context, "foo");
+  ASSERT_TRUE(bool(structTy));
+  ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
+}


        


More information about the Mlir-commits mailing list