[Mlir-commits] [mlir] c8c4598 - [mlir][Type] Remove usages of Type::getKind
River Riddle
llvmlistbot at llvm.org
Fri Aug 7 13:43:54 PDT 2020
Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: c8c45985fba935f28943d6218915d7fe5a5fc807
URL: https://github.com/llvm/llvm-project/commit/c8c45985fba935f28943d6218915d7fe5a5fc807
DIFF: https://github.com/llvm/llvm-project/commit/c8c45985fba935f28943d6218915d7fe5a5fc807.diff
LOG: [mlir][Type] Remove usages of Type::getKind
This is in preparation for removing the use of "kinds" within attributes and types in MLIR.
Differential Revision: https://reviews.llvm.org/D85475
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/Quant/QuantTypes.h
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
mlir/lib/Dialect/Quant/IR/TypeParser.cpp
mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/Traits.cpp
mlir/lib/IR/StandardTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 32ec77ad63e4..449ade85a5b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -101,10 +101,7 @@ class LLVMType : public Type {
}
/// Support for isa/cast.
- static bool classof(Type type) {
- return type.getKind() >= FIRST_NEW_LLVM_TYPE &&
- type.getKind() <= LAST_NEW_LLVM_TYPE;
- }
+ static bool classof(Type type);
LLVMDialect &getDialect();
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 689ac49163eb..ccdc289a9a7c 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -71,10 +71,7 @@ class QuantizedType : public Type {
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
- static bool classof(Type type) {
- return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
- type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
- }
+ static bool classof(Type type);
/// Gets the minimum possible stored by a storageType. storageTypeMin must
/// be greater than or equal to this value.
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index ff9af5bcdeeb..11f7e442f416 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -294,13 +294,7 @@ class ShapedType : public Type {
int64_t getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type) {
- return type.getKind() == StandardTypes::Vector ||
- type.getKind() == StandardTypes::RankedTensor ||
- type.getKind() == StandardTypes::UnrankedTensor ||
- type.getKind() == StandardTypes::UnrankedMemRef ||
- type.getKind() == StandardTypes::MemRef;
- }
+ static bool classof(Type type);
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) {
@@ -358,20 +352,10 @@ class TensorType : public ShapedType {
using ShapedType::ShapedType;
/// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type) {
- // Note: Non standard/builtin types are allowed to exist within tensor
- // types. Dialects are expected to verify that tensor types have a valid
- // element type within that dialect.
- return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
- IndexType>() ||
- (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
- }
+ static bool isValidElementType(Type type);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type) {
- return type.getKind() == StandardTypes::RankedTensor ||
- type.getKind() == StandardTypes::UnrankedTensor;
- }
+ static bool classof(Type type);
};
//===----------------------------------------------------------------------===//
@@ -443,10 +427,7 @@ class BaseMemRefType : public ShapedType {
using ShapedType::ShapedType;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type) {
- return type.getKind() == StandardTypes::MemRef ||
- type.getKind() == StandardTypes::UnrankedMemRef;
- }
+ static bool classof(Type type);
};
//===----------------------------------------------------------------------===//
@@ -629,6 +610,23 @@ class TupleType
}
};
+//===----------------------------------------------------------------------===//
+// Deferred Method Definitions
+//===----------------------------------------------------------------------===//
+
+inline bool BaseMemRefType::classof(Type type) {
+ return type.isa<MemRefType, UnrankedMemRefType>();
+}
+
+inline bool ShapedType::classof(Type type) {
+ return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
+ UnrankedMemRefType, MemRefType>();
+}
+
+inline bool TensorType::classof(Type type) {
+ return type.isa<RankedTensorType, UnrankedTensorType>();
+}
+
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ee429c1f73e3..7d052e3d644e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -27,6 +27,10 @@ using namespace mlir::LLVM;
// LLVMType.
//===----------------------------------------------------------------------===//
+bool LLVMType::classof(Type type) {
+ return llvm::isa<LLVMDialect>(type.getDialect());
+}
+
LLVMDialect &LLVMType::getDialect() {
return static_cast<LLVMDialect &>(Type::getDialect());
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 50924f7b7866..b8bffd35f5a1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -55,11 +55,5 @@ static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
void mlir::linalg::LinalgDialect::printType(Type type,
DialectAsmPrinter &os) const {
- switch (type.getKind()) {
- default:
- llvm_unreachable("Unhandled Linalg type");
- case LinalgTypes::Range:
- print(type.cast<RangeType>(), os);
- break;
- }
+ print(type.cast<RangeType>(), os);
}
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7ef5b4a77c54..ef7d8144b259 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "TypeDetail.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
@@ -23,6 +24,10 @@ unsigned QuantizedType::getFlags() const {
return static_cast<ImplType *>(impl)->flags;
}
+bool QuantizedType::classof(Type type) {
+ return llvm::isa<QuantizationDialect>(type.getDialect());
+}
+
LogicalResult QuantizedType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 60b3b15d02af..c3fc3e5775c6 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -365,18 +365,12 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
/// Print a type registered to this dialect.
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
- switch (type.getKind()) {
- default:
+ if (auto anyType = type.dyn_cast<AnyQuantizedType>())
+ printAnyQuantizedType(anyType, os);
+ else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
+ printUniformQuantizedType(uniformType, os);
+ else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
+ printUniformQuantizedPerAxisType(perAxisType, os);
+ else
llvm_unreachable("Unhandled quantized type");
- case QuantizationTypes::Any:
- printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
- break;
- case QuantizationTypes::UniformQuantized:
- printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
- break;
- case QuantizationTypes::UniformQuantizedPerAxis:
- printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
- os);
- break;
- }
}
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index a79ef0023a7f..bde201e1ef1f 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -19,48 +19,33 @@ static bool isQuantizablePrimitiveType(Type inputType) {
const ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
- switch (inputType.getKind()) {
- default:
- if (isQuantizablePrimitiveType(inputType)) {
- // Supported primitive type (which just is the expressed type).
- return ExpressedToQuantizedConverter{inputType, inputType};
- }
- // Unsupported.
- return ExpressedToQuantizedConverter{inputType, nullptr};
- case StandardTypes::RankedTensor:
- case StandardTypes::UnrankedTensor:
- case StandardTypes::Vector: {
+ if (inputType.isa<TensorType, VectorType>()) {
Type elementType = inputType.cast<ShapedType>().getElementType();
- if (!isQuantizablePrimitiveType(elementType)) {
- // Unsupported.
+ if (!isQuantizablePrimitiveType(elementType))
return ExpressedToQuantizedConverter{inputType, nullptr};
- }
- return ExpressedToQuantizedConverter{
- inputType, inputType.cast<ShapedType>().getElementType()};
- }
+ return ExpressedToQuantizedConverter{inputType, elementType};
}
+ // Supported primitive type (which just is the expressed type).
+ if (isQuantizablePrimitiveType(inputType))
+ return ExpressedToQuantizedConverter{inputType, inputType};
+ // Unsupported.
+ return ExpressedToQuantizedConverter{inputType, nullptr};
}
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
assert(expressedType && "convert() on unsupported conversion");
-
- switch (inputType.getKind()) {
- default:
- if (elementalType.getExpressedType() == expressedType) {
- // If the expressed types match, just use the new elemental type.
- return elementalType;
- }
- // Unsupported.
- return nullptr;
- case StandardTypes::RankedTensor:
- return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
- elementalType);
- case StandardTypes::UnrankedTensor:
+ if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+ return RankedTensorType::get(tensorType.getShape(), elementalType);
+ if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
return UnrankedTensorType::get(elementalType);
- case StandardTypes::Vector:
- return VectorType::get(inputType.cast<VectorType>().getShape(),
- elementalType);
- }
+ if (auto vectorType = inputType.dyn_cast<VectorType>())
+ return VectorType::get(vectorType.getShape(), elementalType);
+
+ // If the expressed types match, just use the new elemental type.
+ if (elementalType.getExpressedType() == expressedType)
+ return elementalType;
+ // Unsupported.
+ return nullptr;
}
ElementsAttr
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index e456f499b745..c303f38a8e0c 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -78,20 +78,17 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
size = alignment;
return type;
}
-
- switch (type.getKind()) {
- case spirv::TypeKind::Struct:
- return decorateType(type.cast<spirv::StructType>(), size, alignment);
- case spirv::TypeKind::Array:
- return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
- case StandardTypes::Vector:
- return decorateType(type.cast<VectorType>(), size, alignment);
- case spirv::TypeKind::RuntimeArray:
+ if (auto structType = type.dyn_cast<spirv::StructType>())
+ return decorateType(structType, size, alignment);
+ if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+ return decorateType(arrayType, size, alignment);
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return decorateType(vectorType, size, alignment);
+ if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
size = std::numeric_limits<Size>().max();
- return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
- default:
- llvm_unreachable("unhandled SPIR-V type");
+ return decorateType(arrayType, alignment);
}
+ llvm_unreachable("unhandled SPIR-V type");
}
Type VulkanLayoutUtils::decorateType(VectorType vectorType,
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 01c305720571..47f4b4ecbe55 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -727,31 +728,11 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
}
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
- switch (type.getKind()) {
- case TypeKind::Array:
- print(type.cast<ArrayType>(), os);
- return;
- case TypeKind::CooperativeMatrix:
- print(type.cast<CooperativeMatrixNVType>(), os);
- return;
- case TypeKind::Pointer:
- print(type.cast<PointerType>(), os);
- return;
- case TypeKind::RuntimeArray:
- print(type.cast<RuntimeArrayType>(), os);
- return;
- case TypeKind::Image:
- print(type.cast<ImageType>(), os);
- return;
- case TypeKind::Struct:
- print(type.cast<StructType>(), os);
- return;
- case TypeKind::Matrix:
- print(type.cast<MatrixType>(), os);
- return;
- default:
- llvm_unreachable("unhandled SPIR-V type");
- }
+ TypeSwitch<Type>(type)
+ .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
+ ImageType, StructType, MatrixType>(
+ [&](auto type) { print(type, os); })
+ .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 5e098e815d98..06b06ddcc913 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1534,8 +1534,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
if (!type.isa<spirv::SPIRVType>())
return false;
- if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
- type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
+ if (isa<SPIRVDialect>(type.getDialect())) {
// TODO: support constant struct
return type.isa<spirv::ArrayType>();
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index be42ca833f21..6144edef2966 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::spirv;
@@ -163,18 +164,11 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
//===----------------------------------------------------------------------===//
bool CompositeType::classof(Type type) {
- switch (type.getKind()) {
- case TypeKind::Array:
- case TypeKind::CooperativeMatrix:
- case TypeKind::Matrix:
- case TypeKind::RuntimeArray:
- case TypeKind::Struct:
- return true;
- case StandardTypes::Vector:
- return isValid(type.cast<VectorType>());
- default:
- return false;
- }
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return isValid(vectorType);
+ return type
+ .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
+ spirv::RuntimeArrayType, spirv::StructType>();
}
bool CompositeType::isValid(VectorType type) {
@@ -183,22 +177,14 @@ bool CompositeType::isValid(VectorType type) {
}
Type CompositeType::getElementType(unsigned index) const {
- switch (getKind()) {
- case spirv::TypeKind::Array:
- return cast<ArrayType>().getElementType();
- case spirv::TypeKind::CooperativeMatrix:
- return cast<CooperativeMatrixNVType>().getElementType();
- case spirv::TypeKind::Matrix:
- return cast<MatrixType>().getColumnType();
- case spirv::TypeKind::RuntimeArray:
- return cast<RuntimeArrayType>().getElementType();
- case spirv::TypeKind::Struct:
- return cast<StructType>().getElementType(index);
- case StandardTypes::Vector:
- return cast<VectorType>().getElementType();
- default:
- llvm_unreachable("invalid composite type");
- }
+ return TypeSwitch<Type, Type>(*this)
+ .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
+ [](auto type) { return type.getElementType(); })
+ .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
+ .Case<StructType>(
+ [index](StructType type) { return type.getElementType(index); })
+ .Default(
+ [](Type) -> Type { llvm_unreachable("invalid composite type"); });
}
unsigned CompositeType::getNumElements() const {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 2a557c489e0b..ca7a25dae035 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -123,16 +123,16 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns llvm::None otherwise.
- auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
+ auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
if (type.isa<VectorType, RankedTensorType>())
- return static_cast<StandardTypes::Kind>(type.getKind());
+ return type.getTypeID();
return llvm::None;
};
// Make sure the composite type, if has, is consistent.
- auto compositeKind1 = getCompositeTypeKind(type1);
- auto compositeKind2 = getCompositeTypeKind(type2);
- Optional<StandardTypes::Kind> resultCompositeKind;
+ Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
+ Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
+ Optional<TypeID> resultCompositeKind;
if (compositeKind1 && compositeKind2) {
// Disallow mixing vector and tensor.
@@ -151,9 +151,9 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
return {};
// Compose the final broadcasted type
- if (resultCompositeKind == StandardTypes::Vector)
+ if (resultCompositeKind == VectorType::getTypeID())
return VectorType::get(resultShape, elementType);
- if (resultCompositeKind == StandardTypes::RankedTensor)
+ if (resultCompositeKind == RankedTensorType::getTypeID())
return RankedTensorType::get(resultShape, elementType);
return elementType;
}
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 0cc82477d380..fc4f555b37f1 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"
@@ -244,16 +245,11 @@ int64_t ShapedType::getSizeInBits() const {
}
ArrayRef<int64_t> ShapedType::getShape() const {
- switch (getKind()) {
- case StandardTypes::Vector:
- return cast<VectorType>().getShape();
- case StandardTypes::RankedTensor:
- return cast<RankedTensorType>().getShape();
- case StandardTypes::MemRef:
- return cast<MemRefType>().getShape();
- default:
- llvm_unreachable("not a ShapedType or not ranked");
- }
+ if (auto vectorType = dyn_cast<VectorType>())
+ return vectorType.getShape();
+ if (auto tensorType = dyn_cast<RankedTensorType>())
+ return tensorType.getShape();
+ return cast<MemRefType>().getShape();
}
int64_t ShapedType::getNumDynamicDims() const {
@@ -305,13 +301,23 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
// Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed.
-static inline LogicalResult checkTensorElementType(Location location,
- Type elementType) {
+static LogicalResult checkTensorElementType(Location location,
+ Type elementType) {
if (!TensorType::isValidElementType(elementType))
return emitError(location, "invalid tensor element type");
return success();
}
+/// Return true if the specified element type is ok in a tensor.
+bool TensorType::isValidElementType(Type type) {
+ // Note: Non standard/builtin types are allowed to exist within tensor
+ // types. Dialects are expected to verify that tensor types have a valid
+ // element type within that dialect.
+ return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+ IndexType>() ||
+ !type.getDialect().getNamespace().empty();
+}
+
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list