[Mlir-commits] [mlir] bb63d24 - [NFC][mlir] Add support for llvm style casting for mlir types

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 24 09:33:12 PDT 2022


Author: Tyker
Date: 2022-08-24T09:33:01-07:00
New Revision: bb63d249f8612f87e819071663d81f516a2bec74

URL: https://github.com/llvm/llvm-project/commit/bb63d249f8612f87e819071663d81f516a2bec74
DIFF: https://github.com/llvm/llvm-project/commit/bb63d249f8612f87e819071663d81f516a2bec74.diff

LOG: [NFC][mlir] Add support for llvm style casting for mlir types

Note:
When operating on a Type hierarchy in which LeafType inherits from MiddleType, which in turn inherits from mlir::Type,
calling LeafType::classof(MiddleType) always returns false.
This is because classof calls the static MiddleType::getTypeID on its argument instead of the dynamic Type::getTypeID,
so in this context classof checks whether the TypeID of LeafType equals the TypeID of MiddleType, which is never true.
This commit bypasses the problem inside CastInfo<To, From>::isPossible by calling classof with an mlir::Type,
but other unsuspecting users of LeafType::classof(MiddleType) would still get an incorrect result.

Added: 
    mlir/unittests/IR/TypeTest.cpp

Modified: 
    mlir/include/mlir/IR/Types.h
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index cb2328ef3f033..7aafd7fe88ccf 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -94,11 +94,9 @@ class Type {
 
   bool operator!() const { return impl == nullptr; }
 
-  template <typename U>
-  bool isa() const;
-  template <typename First, typename Second, typename... Rest>
+  template <typename... Tys>
   bool isa() const;
-  template <typename First, typename... Rest>
+  template <typename... Tys>
   bool isa_and_nonnull() const;
   template <typename U>
   U dyn_cast() const;
@@ -185,6 +183,9 @@ class Type {
   /// Return the abstract type descriptor for this type.
   const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
 
+  /// Return the Type implementation pointer; this is null for a null Type.
+  ImplType *getImpl() const { return impl; }
+
 protected:
   ImplType *impl{nullptr};
 };
@@ -250,34 +251,29 @@ inline ::llvm::hash_code hash_value(Type arg) {
   return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
 }
 
-template <typename U>
-bool Type::isa() const {
-  assert(impl && "isa<> used on a null type.");
-  return U::classof(*this);
-}
-
-template <typename First, typename Second, typename... Rest>
+template <typename... Tys>
 bool Type::isa() const {
-  return isa<First>() || isa<Second, Rest...>();
+  return llvm::isa<Tys...>(*this); // True if *this is any of Tys.
 }
 
-template <typename First, typename... Rest>
+template <typename... Tys>
 bool Type::isa_and_nonnull() const {
-  return impl && isa<First, Rest...>();
+  return llvm::isa_and_present<Tys...>(*this); // False for a null Type.
 }
 
 template <typename U>
 U Type::dyn_cast() const {
-  return isa<U>() ? U(impl) : U(nullptr);
+  return llvm::dyn_cast<U>(*this); // Null U if *this is not a U.
 }
+
 template <typename U>
 U Type::dyn_cast_or_null() const {
-  return (impl && isa<U>()) ? U(impl) : U(nullptr);
+  return llvm::dyn_cast_or_null<U>(*this); // Like dyn_cast; null-tolerant.
 }
+
 template <typename U>
 U Type::cast() const {
-  assert(isa<U>());
-  return U(impl);
+  return llvm::cast<U>(*this); // Asserts that *this is a U.
 }
 
 } // namespace mlir
@@ -325,6 +321,32 @@ struct PointerLikeTypeTraits<mlir::Type> {
   static constexpr int NumLowBitsAvailable = 3;
 };
 
+/// Add support for llvm style casts.
+/// We provide a cast between To and From if From is mlir::Type or derives from
+/// it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+                typename std::enable_if<
+                    std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
+                    std::is_base_of_v<mlir::Type, From>>::type>
+    : NullableValueCastFailed<To>,
+      DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+  /// Arguments are taken as mlir::Type rather than From: for an intermediate
+  /// From, val.getTypeID() in To::classof would resolve to From's static
+  /// getTypeID instead of the dynamic Type::getTypeID, so downcasting from a
+  /// parent to a child would always fail.
+  static inline bool isPossible(mlir::Type ty) {
+    /// std::is_base_of_v also holds for To == From (cv ignored), so casting to
+    /// self or up the hierarchy is statically true; `if constexpr` also avoids
+    /// instantiating To::classof in that case, so To need not provide one.
+    if constexpr (std::is_base_of_v<To, From>)
+      return true;
+    else
+      return To::classof(ty);
+  }
+  static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
+};
+
 } // namespace llvm
 
 #endif // MLIR_IR_TYPES_H

diff  --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 188df7ea5cf58..51978aea6e880 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_unittest(MLIRIRTests
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
   SubElementInterfaceTest.cpp
+  TypeTest.cpp
 
   DEPENDS
   MLIRTestInterfaceIncGen

diff  --git a/mlir/unittests/IR/TypeTest.cpp b/mlir/unittests/IR/TypeTest.cpp
new file mode 100644
index 0000000000000..45be1b792e252
--- /dev/null
+++ b/mlir/unittests/IR/TypeTest.cpp
@@ -0,0 +1,86 @@
+//===- TypeTest.cpp - Type API unit tests ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+/// Mock Type hierarchy: mlir::Type <- MiddleType <- LeafType.
+struct LeafType;
+
+/// Intermediate type. Its classof also accepts LeafType, which derives from
+/// MiddleType; Base::classof handles MiddleType itself.
+struct MiddleType : Type::TypeBase<MiddleType, Type, TypeStorage> {
+  using Base::Base;
+  static bool classof(Type ty) {
+    return ty.getTypeID() == TypeID::get<LeafType>() || Base::classof(ty);
+  }
+};
+
+/// Leaf type; it does not override classof, so it uses the one from its Base.
+struct LeafType : Type::TypeBase<LeafType, MiddleType, TypeStorage> {
+  using Base::Base;
+};
+
+/// Minimal dialect that registers the mock types with a context.
+struct FakeDialect : Dialect {
+  explicit FakeDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<FakeDialect>()) {
+    addTypes<MiddleType, LeafType>();
+  }
+  static constexpr ::llvm::StringLiteral getDialectNamespace() {
+    return ::llvm::StringLiteral("fake");
+  }
+};
+
+TEST(Type, Casting) {
+  MLIRContext ctx;
+  ctx.loadDialect<FakeDialect>();
+
+  Type intTy = IntegerType::get(&ctx, 8);
+  Type nullTy;
+  MiddleType middleTy = MiddleType::get(&ctx);
+  // A LeafType held through its parent's static type; casts from it must
+  // rely on the dynamic TypeID.
+  MiddleType leafTy = LeafType::get(&ctx);
+  Type leaf2Ty = LeafType::get(&ctx);
+
+  EXPECT_TRUE(isa<IntegerType>(intTy));
+  EXPECT_FALSE(isa<FunctionType>(intTy));
+  EXPECT_FALSE(llvm::isa_and_present<IntegerType>(nullTy));
+  EXPECT_TRUE(isa<MiddleType>(middleTy));
+  EXPECT_FALSE(isa<LeafType>(middleTy));
+  EXPECT_TRUE(isa<MiddleType>(leafTy));
+  EXPECT_TRUE(isa<LeafType>(leaf2Ty));
+  EXPECT_TRUE(isa<LeafType>(leafTy));
+
+  EXPECT_TRUE(static_cast<bool>(dyn_cast<IntegerType>(intTy)));
+  EXPECT_FALSE(static_cast<bool>(dyn_cast<FunctionType>(intTy)));
+  EXPECT_FALSE(static_cast<bool>(llvm::cast_if_present<FunctionType>(nullTy)));
+  EXPECT_FALSE(
+      static_cast<bool>(llvm::dyn_cast_if_present<IntegerType>(nullTy)));
+  // Downcasting through the parent type must work for dyn_cast, not only isa.
+  EXPECT_TRUE(static_cast<bool>(dyn_cast<LeafType>(leafTy)));
+  EXPECT_FALSE(static_cast<bool>(dyn_cast<LeafType>(middleTy)));
+
+  EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
+
+  // The Type member functions forward to the llvm:: casting functions.
+  EXPECT_TRUE(intTy.isa<IntegerType>());
+  EXPECT_TRUE((intTy.isa<FunctionType, IntegerType>()));
+  EXPECT_FALSE(intTy.isa<FunctionType>());
+  EXPECT_FALSE(nullTy.isa_and_nonnull<IntegerType>());
+  EXPECT_TRUE(leafTy.isa<LeafType>());
+  EXPECT_FALSE(static_cast<bool>(intTy.dyn_cast<FunctionType>()));
+  EXPECT_FALSE(static_cast<bool>(nullTy.dyn_cast_or_null<IntegerType>()));
+  EXPECT_EQ(8u, intTy.cast<IntegerType>().getWidth());
+}


        


More information about the Mlir-commits mailing list