[Mlir-commits] [mlir] 73440ca - [mlir] Define proper DenseMapInfo for Interfaces

Markus Böck llvmlistbot at llvm.org
Wed Jul 6 03:39:13 PDT 2022


Author: Markus Böck
Date: 2022-07-06T12:27:44+02:00
New Revision: 73440ca9f878f9c4150b339bdd56b234d9167ee9

URL: https://github.com/llvm/llvm-project/commit/73440ca9f878f9c4150b339bdd56b234d9167ee9
DIFF: https://github.com/llvm/llvm-project/commit/73440ca9f878f9c4150b339bdd56b234d9167ee9.diff

LOG: [mlir] Define proper DenseMapInfo for Interfaces

Prior to this patch, using any kind of interface (op interface, attr interface, type interface) as the key of an llvm::DenseSet, llvm::DenseMap, or other related container leads to invalid pointer dereferences at runtime, despite compiling.

The gist of the problem is that, when an interface is used as a key, the llvm::DenseMapInfo specialization for its base type (i.e., one of Operation*, Type, or Attribute) is selected. That specialization uses getFromOpaquePointer with invalid pointer addresses to construct the empty-key and tombstone-key values. The interface is then constructed from this invalid base value, and an attempt is made to look up the implementation in the interface map, which dereferences the invalid pointer address. (For more details and the exact call chain involved, see the GitHub issue below.)

The current workaround is to use the more generic base type (e.g., DenseSet<Operation*> instead of DenseSet<FunctionOpInterface>), but this is strictly worse from a code perspective: it doesn't enforce the invariant, the code is less self-documenting, casts have to be inserted, etc.

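This patch fixes the issue by defining a DenseMapInfo specialization for Interface subclasses that uses a new constructor to create an instance without querying a concept. This allows getEmptyKey and getTombstoneKey to construct an interface holding invalid pointer values without ever looking up its implementation. The existing DenseMapInfo specializations for Attribute, Type, and OpState subclasses are restricted to non-interface types, so the new specialization is the one selected for interfaces.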

Fixes https://github.com/llvm/llvm-project/issues/54908

Differential Revision: https://reviews.llvm.org/D129038

Added: 
    mlir/unittests/IR/InterfaceTest.cpp

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 2195057f55ab9..6c420f13ee1d4 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -266,7 +266,8 @@ template <> struct DenseMapInfo<mlir::Attribute> {
 };
 template <typename T>
 struct DenseMapInfo<
-    T, std::enable_if_t<std::is_base_of<mlir::Attribute, T>::value>>
+    T, std::enable_if_t<std::is_base_of<mlir::Attribute, T>::value &&
+                        !mlir::detail::IsInterface<T>::value>>
     : public DenseMapInfo<mlir::Attribute> {
   static T getEmptyKey() {
     const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 42eccde488919..c98993a2cb93b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1963,8 +1963,9 @@ LogicalResult verifyCastInterfaceOp(
 namespace llvm {
 
 template <typename T>
-struct DenseMapInfo<
-    T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
+struct DenseMapInfo<T,
+                    std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&
+                                     !mlir::detail::IsInterface<T>::value>> {
   static inline T getEmptyKey() {
     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return T::getFromOpaquePointer(pointer);

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 5af73894c9c52..4efc838d021d1 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -282,7 +282,8 @@ template <> struct DenseMapInfo<mlir::Type> {
   static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
 };
 template <typename T>
-struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value>>
+struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
+                                        !mlir::detail::IsInterface<T>::value>>
     : public DenseMapInfo<mlir::Type> {
   static T getEmptyKey() {
     const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 7cbd18f268b65..02eaa1bb78bec 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -78,6 +78,7 @@ class Interface : public BaseType {
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
   template <typename T, typename U>
   using ExternalModel = typename Traits::template ExternalModel<T, U>;
+  using ValueType = ValueT;
 
   /// This is a special trait that registers a given interface with an object.
   template <typename ConcreteT>
@@ -104,6 +105,9 @@ class Interface : public BaseType {
     assert((!t || impl) && "expected value to provide interface instance");
   }
 
+  /// Constructor for DenseMapInfo's empty key and tombstone key.
+  Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {}
+
   /// Support 'classof' by checking if the given object defines the concrete
   /// interface.
   static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
@@ -264,7 +268,40 @@ class InterfaceMap {
   SmallVector<std::pair<TypeID, void *>> interfaces;
 };
 
+template <typename ConcreteType, typename ValueT, typename Traits,
+          typename BaseType,
+          template <typename, template <typename> class> class BaseTrait>
+void isInterfaceImpl(
+    Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait> &);
+
+template <typename T>
+using is_interface_t = decltype(isInterfaceImpl(std::declval<T &>()));
+
+template <typename T>
+using IsInterface = llvm::is_detected<is_interface_t, T>;
+
 } // namespace detail
 } // namespace mlir
 
+namespace llvm {
+
+template <typename T>
+struct DenseMapInfo<T, std::enable_if_t<mlir::detail::IsInterface<T>::value>> {
+  using ValueTypeInfo = llvm::DenseMapInfo<typename T::ValueType>;
+
+  static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); }
+
+  static T getTombstoneKey() {
+    return T(ValueTypeInfo::getTombstoneKey(), nullptr);
+  }
+
+  static unsigned getHashValue(T val) {
+    return ValueTypeInfo::getHashValue(val);
+  }
+
+  static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); }
+};
+
+} // namespace llvm
+
 #endif

diff  --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 326fc0a0cd47c..188df7ea5cf58 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRIRTests
   AttributeTest.cpp
   DialectTest.cpp
+  InterfaceTest.cpp
   InterfaceAttachmentTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp

diff  --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
new file mode 100644
index 0000000000000..e77e8794d6967
--- /dev/null
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -0,0 +1,77 @@
+//===- InterfaceTest.cpp - Test interfaces --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+#include "../../test/lib/Dialect/Test/TestAttributes.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestTypes.h"
+
+using namespace mlir;
+using namespace test;
+
+TEST(InterfaceTest, OpInterfaceDenseMapKey) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
+  OpBuilder builder(module->getBody(), module->getBody()->begin());
+  auto op1 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(),
+                                                builder.getI32Type());
+  auto op2 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(),
+                                                builder.getI32Type());
+  auto op3 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(),
+                                                builder.getI32Type());
+  DenseSet<MemoryEffectOpInterface> opSet;
+  opSet.insert(op1);
+  opSet.insert(op2);
+  opSet.erase(op1);
+  EXPECT_FALSE(opSet.contains(op1));
+  EXPECT_TRUE(opSet.contains(op2));
+  EXPECT_FALSE(opSet.contains(op3));
+}
+
+TEST(InterfaceTest, AttrInterfaceDenseMapKey) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  OpBuilder builder(&context);
+
+  DenseSet<SubElementAttrInterface> attrSet;
+  auto attr1 = builder.getArrayAttr({});
+  auto attr2 = builder.getI32ArrayAttr({0});
+  auto attr3 = builder.getI32ArrayAttr({1});
+  attrSet.insert(attr1);
+  attrSet.insert(attr2);
+  attrSet.erase(attr1);
+  EXPECT_FALSE(attrSet.contains(attr1));
+  EXPECT_TRUE(attrSet.contains(attr2));
+  EXPECT_FALSE(attrSet.contains(attr3));
+}
+
+TEST(InterfaceTest, TypeInterfaceDenseMapKey) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  OpBuilder builder(&context);
+  DenseSet<DataLayoutTypeInterface> typeSet;
+  auto type1 = builder.getType<test::TestTypeWithLayoutType>(1);
+  auto type2 = builder.getType<test::TestTypeWithLayoutType>(2);
+  auto type3 = builder.getType<test::TestTypeWithLayoutType>(3);
+  typeSet.insert(type1);
+  typeSet.insert(type2);
+  typeSet.erase(type1);
+  EXPECT_FALSE(typeSet.contains(type1));
+  EXPECT_TRUE(typeSet.contains(type2));
+  EXPECT_FALSE(typeSet.contains(type3));
+}


        


More information about the Mlir-commits mailing list