[Mlir-commits] [mlir] d410286 - [mlir] Prevent SubElementInterface from going into infinite recursion
Min-Yih Hsu
llvmlistbot at llvm.org
Wed Jun 29 13:58:38 PDT 2022
Author: Min-Yih Hsu
Date: 2022-06-29T13:58:02-07:00
New Revision: d41028610b5372669adcb9b7091fae5250f0a4a8
URL: https://github.com/llvm/llvm-project/commit/d41028610b5372669adcb9b7091fae5250f0a4a8
DIFF: https://github.com/llvm/llvm-project/commit/d41028610b5372669adcb9b7091fae5250f0a4a8.diff
LOG: [mlir] Prevent SubElementInterface from going into infinite recursion
Only mutable types and attributes can cause infinite recursion inside
SubElementInterface::walkSubElements, and there are only a few such
types and attributes, so we introduce new traits for Type
and Attribute: TypeTrait::IsMutable and AttributeTrait::IsMutable,
respectively. They indicate whether a type or attribute is mutable.
Such traits are required if the ImplType defines a `mutate` function.
Then, inside SubElementInterface, we use a set to record the mutable
types and attributes that have already been visited.
Differential Revision: https://reviews.llvm.org/D127537
Added:
mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/SubElementInterfaces.cpp
mlir/test/IR/recursive-type.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/unittests/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index c3a244f021c79..50537f6c9abed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -264,7 +264,8 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
/// structs, but does not in uniquing of identified structs.
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
- DataLayoutTypeInterface::Trait> {
+ DataLayoutTypeInterface::Trait,
+ TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
using Base::Base;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 0d4471ee51c64..40a4acff751b1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -275,8 +275,9 @@ class SampledImageType
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identified and using that identifier in the struct
/// definition for recursive references.
-class StructType : public Type::TypeBase<StructType, CompositeType,
- detail::StructTypeStorage> {
+class StructType
+ : public Type::TypeBase<StructType, CompositeType,
+ detail::StructTypeStorage, TypeTrait::IsMutable> {
public:
using Base::Base;
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index da5b78edf92be..2195057f55ab9 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -231,6 +231,18 @@ class AttributeInterface
friend InterfaceBase;
};
+//===----------------------------------------------------------------------===//
+// Core AttributeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached on an attribute if the corresponding ImplType defines a `mutate`
+/// function with proper signature.
+namespace AttributeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace AttributeTrait
+
} // namespace mlir.
namespace llvm {
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 6d854f66f6ff7..7aaa6881ebd8f 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -53,6 +53,16 @@ class StorageUserTraitBase {
}
};
+namespace StorageUserTrait {
+/// This trait is used to determine if a storage user, like Type, is mutable
+/// or not. A storage user is mutable if ImplType of the derived class defines
+/// a `mutate` function with a proper signature. Note that this trait is not
+/// supposed to be used publicly. Users should use alias names like
+/// `TypeTrait::IsMutable` instead.
+template <typename ConcreteType>
+struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
+} // namespace StorageUserTrait
+
//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
@@ -173,6 +183,10 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {
+ static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
+ ConcreteT>::value,
+ "The `mutate` function expects mutable trait "
+ "(e.g. TypeTrait::IsMutable) to be attached on parent.");
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
std::forward<Args>(args)...);
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index f2a69660ae9f7..5af73894c9c52 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -222,6 +222,18 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
friend InterfaceBase;
};
+//===----------------------------------------------------------------------===//
+// Core TypeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if a type is mutable or not. It is attached
+/// on a type if the corresponding ImplType defines a `mutate` function with
+/// a proper signature.
+namespace TypeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace TypeTrait
+
//===----------------------------------------------------------------------===//
// Type Utils
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index 0a4875cf7d8aa..4059b99b5db24 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -8,12 +8,16 @@
#include "mlir/IR/SubElementInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+
using namespace mlir;
template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn) {
+ function_ref<void(Type)> walkTypesFn,
+ DenseSet<Attribute> &visitedAttrs,
+ DenseSet<Type> &visitedTypes) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
@@ -21,9 +25,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!attr)
return;
+ // Avoid infinite recursion when visiting sub attributes later, if this
+ // is a mutable attribute.
+ if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
+ if (!visitedAttrs.insert(attr).second)
+ return;
+ }
+
// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+ walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
// Walk this attribute.
walkAttrsFn(attr);
@@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!type)
return;
+ // Avoid infinite recursion when visiting sub types later, if this
+ // is a mutable type.
+ if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
+ if (!visitedTypes.insert(type).second)
+ return;
+ }
+
// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
- walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+ walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
// Walk this type.
walkTypesFn(type);
@@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+ DenseSet<Attribute> visitedAttrs;
+ DenseSet<Type> visitedTypes;
+ walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
}
void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
- walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+ DenseSet<Attribute> visitedAttrs;
+ DenseSet<Type> visitedTypes;
+ walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+ visitedTypes);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index e66d7fd5a5ab4..bc9b2cdbea6b6 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,11 +1,17 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
+// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
+
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
// CHECK: !test.test_rec<c, test_rec<c>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
+ // Make sure walkSubElementType, which is used to generate aliases, doesn't go
+ // into infinite recursion.
+ // CHECK: !testrec
+ "test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3585d43b3128b..653f0e10ae319 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
+ if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+ if (recType.getName() == "type_to_alias") {
+ // We only make alias for a specific recursive type.
+ os << "testrec";
+ return AliasResult::FinalAlias;
+ }
+ }
return AliasResult::NoAlias;
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index bd6421d04735e..8772efd8590fc 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -21,6 +21,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
- TestRecursiveTypeStorage> {
+ TestRecursiveTypeStorage,
+ ::mlir::SubElementTypeInterface::Trait,
+ ::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;
@@ -141,10 +144,16 @@ class TestRecursiveType
/// Body getter and setter.
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
- ::mlir::Type getBody() { return getImpl()->body; }
+ ::mlir::Type getBody() const { return getImpl()->body; }
/// Name/key getter.
::llvm::StringRef getName() { return getImpl()->name; }
+
+ void walkImmediateSubElements(
+ ::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
+ ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
+ walkTypesFn(getBody());
+ }
};
} // namespace test
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index ec89b14837de9..f8e5e46e5ac3b 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
MLIRDialect)
add_subdirectory(Affine)
+add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
-
add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000000000..92af1856c68e0
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRLLVMIRTests
+ LLVMTypeTest.cpp
+)
+target_link_libraries(MLIRLLVMIRTests
+ PRIVATE
+ MLIRLLVMDialect
+ )
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
new file mode 100644
index 0000000000000..1badc44ce2132
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
@@ -0,0 +1,27 @@
+//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test fixture for LLVM dialect tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+class LLVMIRTest : public ::testing::Test {
+protected:
+ LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
+
+ mlir::MLIRContext context;
+};
+
+#endif
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
new file mode 100644
index 0000000000000..9c0ea4f14d766
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -0,0 +1,20 @@
+//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMTestBase.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/SubElementInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+TEST_F(LLVMIRTest, IsStructTypeMutable) {
+ auto structTy = LLVMStructType::getIdentified(&context, "foo");
+ ASSERT_TRUE(bool(structTy));
+ ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
+}
More information about the Mlir-commits
mailing list