[Mlir-commits] [mlir] c8c4598 - [mlir][Type] Remove usages of Type::getKind

River Riddle llvmlistbot at llvm.org
Fri Aug 7 13:43:54 PDT 2020


Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: c8c45985fba935f28943d6218915d7fe5a5fc807

URL: https://github.com/llvm/llvm-project/commit/c8c45985fba935f28943d6218915d7fe5a5fc807
DIFF: https://github.com/llvm/llvm-project/commit/c8c45985fba935f28943d6218915d7fe5a5fc807.diff

LOG: [mlir][Type] Remove usages of Type::getKind

This is in preparation for removing the use of "kinds" within attributes and types in MLIR.

Differential Revision: https://reviews.llvm.org/D85475

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/Quant/QuantTypes.h
    mlir/include/mlir/IR/StandardTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
    mlir/lib/Dialect/Quant/IR/TypeParser.cpp
    mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
    mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/Traits.cpp
    mlir/lib/IR/StandardTypes.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 32ec77ad63e4..449ade85a5b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -101,10 +101,7 @@ class LLVMType : public Type {
   }
 
   /// Support for isa/cast.
-  static bool classof(Type type) {
-    return type.getKind() >= FIRST_NEW_LLVM_TYPE &&
-           type.getKind() <= LAST_NEW_LLVM_TYPE;
-  }
+  static bool classof(Type type);
 
   LLVMDialect &getDialect();
 

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 689ac49163eb..ccdc289a9a7c 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -71,10 +71,7 @@ class QuantizedType : public Type {
                                int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
-  static bool classof(Type type) {
-    return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
-           type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
-  }
+  static bool classof(Type type);
 
   /// Gets the minimum possible stored by a storageType. storageTypeMin must
   /// be greater than or equal to this value.

diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index ff9af5bcdeeb..11f7e442f416 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -294,13 +294,7 @@ class ShapedType : public Type {
   int64_t getSizeInBits() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type) {
-    return type.getKind() == StandardTypes::Vector ||
-           type.getKind() == StandardTypes::RankedTensor ||
-           type.getKind() == StandardTypes::UnrankedTensor ||
-           type.getKind() == StandardTypes::UnrankedMemRef ||
-           type.getKind() == StandardTypes::MemRef;
-  }
+  static bool classof(Type type);
 
   /// Whether the given dimension size indicates a dynamic dimension.
   static constexpr bool isDynamic(int64_t dSize) {
@@ -358,20 +352,10 @@ class TensorType : public ShapedType {
   using ShapedType::ShapedType;
 
   /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type) {
-    // Note: Non standard/builtin types are allowed to exist within tensor
-    // types. Dialects are expected to verify that tensor types have a valid
-    // element type within that dialect.
-    return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
-                    IndexType>() ||
-           (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
-  }
+  static bool isValidElementType(Type type);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type) {
-    return type.getKind() == StandardTypes::RankedTensor ||
-           type.getKind() == StandardTypes::UnrankedTensor;
-  }
+  static bool classof(Type type);
 };
 
 //===----------------------------------------------------------------------===//
@@ -443,10 +427,7 @@ class BaseMemRefType : public ShapedType {
   using ShapedType::ShapedType;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type) {
-    return type.getKind() == StandardTypes::MemRef ||
-           type.getKind() == StandardTypes::UnrankedMemRef;
-  }
+  static bool classof(Type type);
 };
 
 //===----------------------------------------------------------------------===//
@@ -629,6 +610,23 @@ class TupleType
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Deferred Method Definitions
+//===----------------------------------------------------------------------===//
+
+inline bool BaseMemRefType::classof(Type type) {
+  return type.isa<MemRefType, UnrankedMemRefType>();
+}
+
+inline bool ShapedType::classof(Type type) {
+  return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
+                  UnrankedMemRefType, MemRefType>();
+}
+
+inline bool TensorType::classof(Type type) {
+  return type.isa<RankedTensorType, UnrankedTensorType>();
+}
+
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ee429c1f73e3..7d052e3d644e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -27,6 +27,10 @@ using namespace mlir::LLVM;
 // LLVMType.
 //===----------------------------------------------------------------------===//
 
+bool LLVMType::classof(Type type) {
+  return llvm::isa<LLVMDialect>(type.getDialect());
+}
+
 LLVMDialect &LLVMType::getDialect() {
   return static_cast<LLVMDialect &>(Type::getDialect());
 }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 50924f7b7866..b8bffd35f5a1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -55,11 +55,5 @@ static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
 
 void mlir::linalg::LinalgDialect::printType(Type type,
                                             DialectAsmPrinter &os) const {
-  switch (type.getKind()) {
-  default:
-    llvm_unreachable("Unhandled Linalg type");
-  case LinalgTypes::Range:
-    print(type.cast<RangeType>(), os);
-    break;
-  }
+  print(type.cast<RangeType>(), os);
 }

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7ef5b4a77c54..ef7d8144b259 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Quant/QuantTypes.h"
 #include "TypeDetail.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
@@ -23,6 +24,10 @@ unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
 }
 
+bool QuantizedType::classof(Type type) {
+  return llvm::isa<QuantizationDialect>(type.getDialect());
+}
+
 LogicalResult QuantizedType::verifyConstructionInvariants(
     Location loc, unsigned flags, Type storageType, Type expressedType,
     int64_t storageTypeMin, int64_t storageTypeMax) {

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 60b3b15d02af..c3fc3e5775c6 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -365,18 +365,12 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
 
 /// Print a type registered to this dialect.
 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
-  switch (type.getKind()) {
-  default:
+  if (auto anyType = type.dyn_cast<AnyQuantizedType>())
+    printAnyQuantizedType(anyType, os);
+  else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
+    printUniformQuantizedType(uniformType, os);
+  else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
+    printUniformQuantizedPerAxisType(perAxisType, os);
+  else
     llvm_unreachable("Unhandled quantized type");
-  case QuantizationTypes::Any:
-    printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
-    break;
-  case QuantizationTypes::UniformQuantized:
-    printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
-    break;
-  case QuantizationTypes::UniformQuantizedPerAxis:
-    printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
-                                     os);
-    break;
-  }
 }

diff  --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index a79ef0023a7f..bde201e1ef1f 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -19,48 +19,33 @@ static bool isQuantizablePrimitiveType(Type inputType) {
 
 const ExpressedToQuantizedConverter
 ExpressedToQuantizedConverter::forInputType(Type inputType) {
-  switch (inputType.getKind()) {
-  default:
-    if (isQuantizablePrimitiveType(inputType)) {
-      // Supported primitive type (which just is the expressed type).
-      return ExpressedToQuantizedConverter{inputType, inputType};
-    }
-    // Unsupported.
-    return ExpressedToQuantizedConverter{inputType, nullptr};
-  case StandardTypes::RankedTensor:
-  case StandardTypes::UnrankedTensor:
-  case StandardTypes::Vector: {
+  if (inputType.isa<TensorType, VectorType>()) {
     Type elementType = inputType.cast<ShapedType>().getElementType();
-    if (!isQuantizablePrimitiveType(elementType)) {
-      // Unsupported.
+    if (!isQuantizablePrimitiveType(elementType))
       return ExpressedToQuantizedConverter{inputType, nullptr};
-    }
-    return ExpressedToQuantizedConverter{
-        inputType, inputType.cast<ShapedType>().getElementType()};
-  }
+    return ExpressedToQuantizedConverter{inputType, elementType};
   }
+  // Supported primitive type (which just is the expressed type).
+  if (isQuantizablePrimitiveType(inputType))
+    return ExpressedToQuantizedConverter{inputType, inputType};
+  // Unsupported.
+  return ExpressedToQuantizedConverter{inputType, nullptr};
 }
 
 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");
-
-  switch (inputType.getKind()) {
-  default:
-    if (elementalType.getExpressedType() == expressedType) {
-      // If the expressed types match, just use the new elemental type.
-      return elementalType;
-    }
-    // Unsupported.
-    return nullptr;
-  case StandardTypes::RankedTensor:
-    return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
-                                 elementalType);
-  case StandardTypes::UnrankedTensor:
+  if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+    return RankedTensorType::get(tensorType.getShape(), elementalType);
+  if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
     return UnrankedTensorType::get(elementalType);
-  case StandardTypes::Vector:
-    return VectorType::get(inputType.cast<VectorType>().getShape(),
-                           elementalType);
-  }
+  if (auto vectorType = inputType.dyn_cast<VectorType>())
+    return VectorType::get(vectorType.getShape(), elementalType);
+
+  // If the expressed types match, just use the new elemental type.
+  if (elementalType.getExpressedType() == expressedType)
+    return elementalType;
+  // Unsupported.
+  return nullptr;
 }
 
 ElementsAttr

diff  --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index e456f499b745..c303f38a8e0c 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -78,20 +78,17 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
     size = alignment;
     return type;
   }
-
-  switch (type.getKind()) {
-  case spirv::TypeKind::Struct:
-    return decorateType(type.cast<spirv::StructType>(), size, alignment);
-  case spirv::TypeKind::Array:
-    return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
-  case StandardTypes::Vector:
-    return decorateType(type.cast<VectorType>(), size, alignment);
-  case spirv::TypeKind::RuntimeArray:
+  if (auto structType = type.dyn_cast<spirv::StructType>())
+    return decorateType(structType, size, alignment);
+  if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+    return decorateType(arrayType, size, alignment);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return decorateType(vectorType, size, alignment);
+  if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
     size = std::numeric_limits<Size>().max();
-    return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
-  default:
-    llvm_unreachable("unhandled SPIR-V type");
+    return decorateType(arrayType, alignment);
   }
+  llvm_unreachable("unhandled SPIR-V type");
 }
 
 Type VulkanLayoutUtils::decorateType(VectorType vectorType,

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 01c305720571..47f4b4ecbe55 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -727,31 +728,11 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
 }
 
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
-  switch (type.getKind()) {
-  case TypeKind::Array:
-    print(type.cast<ArrayType>(), os);
-    return;
-  case TypeKind::CooperativeMatrix:
-    print(type.cast<CooperativeMatrixNVType>(), os);
-    return;
-  case TypeKind::Pointer:
-    print(type.cast<PointerType>(), os);
-    return;
-  case TypeKind::RuntimeArray:
-    print(type.cast<RuntimeArrayType>(), os);
-    return;
-  case TypeKind::Image:
-    print(type.cast<ImageType>(), os);
-    return;
-  case TypeKind::Struct:
-    print(type.cast<StructType>(), os);
-    return;
-  case TypeKind::Matrix:
-    print(type.cast<MatrixType>(), os);
-    return;
-  default:
-    llvm_unreachable("unhandled SPIR-V type");
-  }
+  TypeSwitch<Type>(type)
+      .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
+            ImageType, StructType, MatrixType>(
+          [&](auto type) { print(type, os); })
+      .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 5e098e815d98..06b06ddcc913 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1534,8 +1534,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
   if (!type.isa<spirv::SPIRVType>())
     return false;
 
-  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
-      type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
+  if (isa<SPIRVDialect>(type.getDialect())) {
     // TODO: support constant struct
     return type.isa<spirv::ArrayType>();
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index be42ca833f21..6144edef2966 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -163,18 +164,11 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
 //===----------------------------------------------------------------------===//
 
 bool CompositeType::classof(Type type) {
-  switch (type.getKind()) {
-  case TypeKind::Array:
-  case TypeKind::CooperativeMatrix:
-  case TypeKind::Matrix:
-  case TypeKind::RuntimeArray:
-  case TypeKind::Struct:
-    return true;
-  case StandardTypes::Vector:
-    return isValid(type.cast<VectorType>());
-  default:
-    return false;
-  }
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return isValid(vectorType);
+  return type
+      .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
+           spirv::RuntimeArrayType, spirv::StructType>();
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -183,22 +177,14 @@ bool CompositeType::isValid(VectorType type) {
 }
 
 Type CompositeType::getElementType(unsigned index) const {
-  switch (getKind()) {
-  case spirv::TypeKind::Array:
-    return cast<ArrayType>().getElementType();
-  case spirv::TypeKind::CooperativeMatrix:
-    return cast<CooperativeMatrixNVType>().getElementType();
-  case spirv::TypeKind::Matrix:
-    return cast<MatrixType>().getColumnType();
-  case spirv::TypeKind::RuntimeArray:
-    return cast<RuntimeArrayType>().getElementType();
-  case spirv::TypeKind::Struct:
-    return cast<StructType>().getElementType(index);
-  case StandardTypes::Vector:
-    return cast<VectorType>().getElementType();
-  default:
-    llvm_unreachable("invalid composite type");
-  }
+  return TypeSwitch<Type, Type>(*this)
+      .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
+          [](auto type) { return type.getElementType(); })
+      .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
+      .Case<StructType>(
+          [index](StructType type) { return type.getElementType(index); })
+      .Default(
+          [](Type) -> Type { llvm_unreachable("invalid composite type"); });
 }
 
 unsigned CompositeType::getNumElements() const {

diff  --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 2a557c489e0b..ca7a25dae035 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -123,16 +123,16 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 
   // Returns the type kind if the given type is a vector or ranked tensor type.
   // Returns llvm::None otherwise.
-  auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
+  auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
     if (type.isa<VectorType, RankedTensorType>())
-      return static_cast<StandardTypes::Kind>(type.getKind());
+      return type.getTypeID();
     return llvm::None;
   };
 
   // Make sure the composite type, if has, is consistent.
-  auto compositeKind1 = getCompositeTypeKind(type1);
-  auto compositeKind2 = getCompositeTypeKind(type2);
-  Optional<StandardTypes::Kind> resultCompositeKind;
+  Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
+  Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
+  Optional<TypeID> resultCompositeKind;
 
   if (compositeKind1 && compositeKind2) {
     // Disallow mixing vector and tensor.
@@ -151,9 +151,9 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
     return {};
 
   // Compose the final broadcasted type
-  if (resultCompositeKind == StandardTypes::Vector)
+  if (resultCompositeKind == VectorType::getTypeID())
     return VectorType::get(resultShape, elementType);
-  if (resultCompositeKind == StandardTypes::RankedTensor)
+  if (resultCompositeKind == RankedTensorType::getTypeID())
     return RankedTensorType::get(resultShape, elementType);
   return elementType;
 }

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 0cc82477d380..fc4f555b37f1 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Twine.h"
 
@@ -244,16 +245,11 @@ int64_t ShapedType::getSizeInBits() const {
 }
 
 ArrayRef<int64_t> ShapedType::getShape() const {
-  switch (getKind()) {
-  case StandardTypes::Vector:
-    return cast<VectorType>().getShape();
-  case StandardTypes::RankedTensor:
-    return cast<RankedTensorType>().getShape();
-  case StandardTypes::MemRef:
-    return cast<MemRefType>().getShape();
-  default:
-    llvm_unreachable("not a ShapedType or not ranked");
-  }
+  if (auto vectorType = dyn_cast<VectorType>())
+    return vectorType.getShape();
+  if (auto tensorType = dyn_cast<RankedTensorType>())
+    return tensorType.getShape();
+  return cast<MemRefType>().getShape();
 }
 
 int64_t ShapedType::getNumDynamicDims() const {
@@ -305,13 +301,23 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
 
 // Check if "elementType" can be an element type of a tensor. Emit errors if
 // location is not nullptr.  Returns failure if check failed.
-static inline LogicalResult checkTensorElementType(Location location,
-                                                   Type elementType) {
+static LogicalResult checkTensorElementType(Location location,
+                                            Type elementType) {
   if (!TensorType::isValidElementType(elementType))
     return emitError(location, "invalid tensor element type");
   return success();
 }
 
+/// Return true if the specified element type is ok in a tensor.
+bool TensorType::isValidElementType(Type type) {
+  // Note: Non standard/builtin types are allowed to exist within tensor
+  // types. Dialects are expected to verify that tensor types have a valid
+  // element type within that dialect.
+  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+                  IndexType>() ||
+         !type.getDialect().getNamespace().empty();
+}
+
 //===----------------------------------------------------------------------===//
 // RankedTensorType
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list