[Mlir-commits] [mlir] 6527712 - [mlir][Type] Remove the remaining usages of Type::getKind in preparation for its removal
River Riddle
llvmlistbot at llvm.org
Wed Aug 12 19:40:41 PDT 2020
Author: River Riddle
Date: 2020-08-12T19:33:58-07:00
New Revision: 65277126bf90401436e018fcce0fc636d34ea771
URL: https://github.com/llvm/llvm-project/commit/65277126bf90401436e018fcce0fc636d34ea771
DIFF: https://github.com/llvm/llvm-project/commit/65277126bf90401436e018fcce0fc636d34ea771.diff
LOG: [mlir][Type] Remove the remaining usages of Type::getKind in preparation for its removal
This revision removes all of the lingering usages of Type::getKind. A consequence of this is that FloatType is now split into 4 derived types that represent each of the possible float types (BFloat16Type, Float16Type, Float32Type, and Float64Type). Other than this split, this revision is NFC.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D85566
Added:
Modified:
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/StandardTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index a4a4566e3eb8..6ceddec53377 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -180,25 +180,18 @@ class IntegerType
// FloatType
//===----------------------------------------------------------------------===//
-class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
+class FloatType : public Type {
public:
- using Base::Base;
-
- static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
+ using Type::Type;
// Convenience factories.
- static FloatType getBF16(MLIRContext *ctx) {
- return get(StandardTypes::BF16, ctx);
- }
- static FloatType getF16(MLIRContext *ctx) {
- return get(StandardTypes::F16, ctx);
- }
- static FloatType getF32(MLIRContext *ctx) {
- return get(StandardTypes::F32, ctx);
- }
- static FloatType getF64(MLIRContext *ctx) {
- return get(StandardTypes::F64, ctx);
- }
+ static FloatType getBF16(MLIRContext *ctx);
+ static FloatType getF16(MLIRContext *ctx);
+ static FloatType getF32(MLIRContext *ctx);
+ static FloatType getF64(MLIRContext *ctx);
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(Type type);
/// Return the bitwidth of this float type.
unsigned getWidth();
@@ -207,6 +200,67 @@ class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
const llvm::fltSemantics &getFloatSemantics();
};
+//===----------------------------------------------------------------------===//
+// BFloat16Type
+
+class BFloat16Type
+ : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
+public:
+ using Base::Base;
+
+ /// Return an instance of the bfloat16 type.
+ static BFloat16Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getBF16(MLIRContext *ctx) {
+ return BFloat16Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float16Type
+
+class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
+public:
+ using Base::Base;
+
+ /// Return an instance of the float16 type.
+ static Float16Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF16(MLIRContext *ctx) {
+ return Float16Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float32Type
+
+class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
+public:
+ using Base::Base;
+
+ /// Return an instance of the float32 type.
+ static Float32Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF32(MLIRContext *ctx) {
+ return Float32Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float64Type
+
+class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
+public:
+ using Base::Base;
+
+ /// Return an instance of the float64 type.
+ static Float64Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF64(MLIRContext *ctx) {
+ return Float64Type::get(ctx);
+}
+
//===----------------------------------------------------------------------===//
// NoneType
//===----------------------------------------------------------------------===//
@@ -623,6 +677,10 @@ inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}
+inline bool FloatType::classof(Type type) {
+ return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
+}
+
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 08d1c32d13c5..117cc9721af7 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -210,19 +210,15 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
- switch (type.getKind()) {
- case mlir::StandardTypes::F32:
+ if (type.isa<Float32Type>())
return LLVM::LLVMType::getFloatTy(&getContext());
- case mlir::StandardTypes::F64:
+ if (type.isa<Float64Type>())
return LLVM::LLVMType::getDoubleTy(&getContext());
- case mlir::StandardTypes::F16:
+ if (type.isa<Float16Type>())
return LLVM::LLVMType::getHalfTy(&getContext());
- case mlir::StandardTypes::BF16: {
+ if (type.isa<BFloat16Type>())
return LLVM::LLVMType::getBFloatTy(&getContext());
- }
- default:
- llvm_unreachable("non-float type in convertFloatType");
- }
+ llvm_unreachable("non-float type in convertFloatType");
}
// Convert a `ComplexType` to an LLVM type. The result is a complex number
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 59ae48ebf3e7..aa611d76a67a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -10,6 +10,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::LLVM;
@@ -23,46 +24,28 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
/// Returns the keyword to use for the given type.
static StringRef getTypeKeyword(LLVMType type) {
- switch (type.getKind()) {
- case LLVMType::VoidType:
- return "void";
- case LLVMType::HalfType:
- return "half";
- case LLVMType::BFloatType:
- return "bfloat";
- case LLVMType::FloatType:
- return "float";
- case LLVMType::DoubleType:
- return "double";
- case LLVMType::FP128Type:
- return "fp128";
- case LLVMType::X86FP80Type:
- return "x86_fp80";
- case LLVMType::PPCFP128Type:
- return "ppc_fp128";
- case LLVMType::X86MMXType:
- return "x86_mmx";
- case LLVMType::TokenType:
- return "token";
- case LLVMType::LabelType:
- return "label";
- case LLVMType::MetadataType:
- return "metadata";
- case LLVMType::FunctionType:
- return "func";
- case LLVMType::IntegerType:
- return "i";
- case LLVMType::PointerType:
- return "ptr";
- case LLVMType::FixedVectorType:
- case LLVMType::ScalableVectorType:
- return "vec";
- case LLVMType::ArrayType:
- return "array";
- case LLVMType::StructType:
- return "struct";
- }
- llvm_unreachable("unhandled type kind");
+ return TypeSwitch<Type, StringRef>(type)
+ .Case<LLVMVoidType>([&](Type) { return "void"; })
+ .Case<LLVMHalfType>([&](Type) { return "half"; })
+ .Case<LLVMBFloatType>([&](Type) { return "bfloat"; })
+ .Case<LLVMFloatType>([&](Type) { return "float"; })
+ .Case<LLVMDoubleType>([&](Type) { return "double"; })
+ .Case<LLVMFP128Type>([&](Type) { return "fp128"; })
+ .Case<LLVMX86FP80Type>([&](Type) { return "x86_fp80"; })
+ .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
+ .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
+ .Case<LLVMTokenType>([&](Type) { return "token"; })
+ .Case<LLVMLabelType>([&](Type) { return "label"; })
+ .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
+ .Case<LLVMFunctionType>([&](Type) { return "func"; })
+ .Case<LLVMIntegerType>([&](Type) { return "i"; })
+ .Case<LLVMPointerType>([&](Type) { return "ptr"; })
+ .Case<LLVMVectorType>([&](Type) { return "vec"; })
+ .Case<LLVMArrayType>([&](Type) { return "array"; })
+ .Case<LLVMStructType>([&](Type) { return "struct"; })
+ .Default([](Type) -> StringRef {
+ llvm_unreachable("unexpected 'llvm' type kind");
+ });
}
/// Prints the body of a structure type. Uses `stack` to avoid printing
@@ -153,14 +136,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
return;
}
- unsigned kind = type.getKind();
os << getTypeKeyword(type);
- // Trivial types only consist of their keyword.
- if (LLVMType::FIRST_TRIVIAL_TYPE <= kind &&
- kind <= LLVMType::LAST_TRIVIAL_TYPE)
- return;
-
if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
os << intType.getBitWidth();
return;
@@ -190,7 +167,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
if (auto structType = type.dyn_cast<LLVMStructType>())
return printStructType(os, structType, stack);
- printFunctionType(os, type.cast<LLVMFunctionType>(), stack);
+ if (auto funcType = type.dyn_cast<LLVMFunctionType>())
+ return printFunctionType(os, funcType, stack);
}
void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 6144edef2966..b52ea812306e 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
@@ -188,108 +189,70 @@ Type CompositeType::getElementType(unsigned index) const {
}
unsigned CompositeType::getNumElements() const {
- switch (getKind()) {
- case spirv::TypeKind::Array:
- return cast<ArrayType>().getNumElements();
- case spirv::TypeKind::CooperativeMatrix:
+ if (auto arrayType = dyn_cast<ArrayType>())
+ return arrayType.getNumElements();
+ if (auto matrixType = dyn_cast<MatrixType>())
+ return matrixType.getNumColumns();
+ if (auto structType = dyn_cast<StructType>())
+ return structType.getNumElements();
+ if (auto vectorType = dyn_cast<VectorType>())
+ return vectorType.getNumElements();
+ if (isa<CooperativeMatrixNVType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
- case spirv::TypeKind::Matrix:
- return cast<MatrixType>().getNumColumns();
- case spirv::TypeKind::RuntimeArray:
+ }
+ if (isa<RuntimeArrayType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
- case spirv::TypeKind::Struct:
- return cast<StructType>().getNumElements();
- case StandardTypes::Vector:
- return cast<VectorType>().getNumElements();
- default:
- llvm_unreachable("invalid composite type");
}
+ llvm_unreachable("invalid composite type");
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- switch (getKind()) {
- case TypeKind::CooperativeMatrix:
- case TypeKind::RuntimeArray:
- return false;
- default:
- return true;
- }
+ return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
- switch (getKind()) {
- case spirv::TypeKind::Array:
- cast<ArrayType>().getExtensions(extensions, storage);
- break;
- case spirv::TypeKind::CooperativeMatrix:
- cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
- break;
- case spirv::TypeKind::Matrix:
- cast<MatrixType>().getExtensions(extensions, storage);
- break;
- case spirv::TypeKind::RuntimeArray:
- cast<RuntimeArrayType>().getExtensions(extensions, storage);
- break;
- case spirv::TypeKind::Struct:
- cast<StructType>().getExtensions(extensions, storage);
- break;
- case StandardTypes::Vector:
- cast<VectorType>().getElementType().cast<ScalarType>().getExtensions(
- extensions, storage);
- break;
- default:
- llvm_unreachable("invalid composite type");
- }
+ TypeSwitch<Type>(*this)
+ .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
+ StructType>(
+ [&](auto type) { type.getExtensions(extensions, storage); })
+ .Case<VectorType>([&](VectorType type) {
+ return type.getElementType().cast<ScalarType>().getExtensions(
+ extensions, storage);
+ })
+ .Default([](Type) { llvm_unreachable("invalid composite type"); });
}
void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
- switch (getKind()) {
- case spirv::TypeKind::Array:
- cast<ArrayType>().getCapabilities(capabilities, storage);
- break;
- case spirv::TypeKind::CooperativeMatrix:
- cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
- break;
- case spirv::TypeKind::Matrix:
- cast<MatrixType>().getCapabilities(capabilities, storage);
- break;
- case spirv::TypeKind::RuntimeArray:
- cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
- break;
- case spirv::TypeKind::Struct:
- cast<StructType>().getCapabilities(capabilities, storage);
- break;
- case StandardTypes::Vector:
- cast<VectorType>().getElementType().cast<ScalarType>().getCapabilities(
- capabilities, storage);
- break;
- default:
- llvm_unreachable("invalid composite type");
- }
+ TypeSwitch<Type>(*this)
+ .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
+ StructType>(
+ [&](auto type) { type.getCapabilities(capabilities, storage); })
+ .Case<VectorType>([&](VectorType type) {
+ return type.getElementType().cast<ScalarType>().getCapabilities(
+ capabilities, storage);
+ })
+ .Default([](Type) { llvm_unreachable("invalid composite type"); });
}
Optional<int64_t> CompositeType::getSizeInBytes() {
- switch (getKind()) {
- case spirv::TypeKind::Array:
- return cast<ArrayType>().getSizeInBytes();
- case spirv::TypeKind::Struct:
- return cast<StructType>().getSizeInBytes();
- case StandardTypes::Vector: {
- auto elementSize =
- cast<VectorType>().getElementType().cast<ScalarType>().getSizeInBytes();
+ if (auto arrayType = dyn_cast<ArrayType>())
+ return arrayType.getSizeInBytes();
+ if (auto structType = dyn_cast<StructType>())
+ return structType.getSizeInBytes();
+ if (auto vectorType = dyn_cast<VectorType>()) {
+ Optional<int64_t> elementSize =
+ vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
if (!elementSize)
return llvm::None;
- return *elementSize * cast<VectorType>().getNumElements();
- }
- default:
- return llvm::None;
+ return *elementSize * vectorType.getNumElements();
}
+ return llvm::None;
}
//===----------------------------------------------------------------------===//
@@ -741,8 +704,7 @@ Optional<int64_t> ScalarType::getSizeInBytes() {
bool SPIRVType::classof(Type type) {
// Allow SPIR-V dialect types
- if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
- type.getKind() <= TypeKind::LAST_SPIRV_TYPE)
+ if (llvm::isa<SPIRVDialect>(type.getDialect()))
return true;
if (type.isa<ScalarType>())
return true;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 7cf02149fa25..511ec9bf2b4e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -141,28 +142,14 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
- switch (type.getKind()) {
- case ShapeTypes::Component:
- os << "component";
- return;
- case ShapeTypes::Element:
- os << "element";
- return;
- case ShapeTypes::Size:
- os << "size";
- return;
- case ShapeTypes::Shape:
- os << "shape";
- return;
- case ShapeTypes::ValueShape:
- os << "value_shape";
- return;
- case ShapeTypes::Witness:
- os << "witness";
- return;
- default:
- llvm_unreachable("unexpected 'shape' type kind");
- }
+ TypeSwitch<Type>(type)
+ .Case<ComponentType>([&](Type) { os << "component"; })
+ .Case<ElementType>([&](Type) { os << "element"; })
+ .Case<ShapeType>([&](Type) { os << "shape"; })
+ .Case<SizeType>([&](Type) { os << "size"; })
+ .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
+ .Case<WitnessType>([&](Type) { os << "witness"; })
+ .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4470f5b4b826..c8b4a864fb63 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1576,128 +1576,95 @@ void ModulePrinter::printType(Type type) {
}
}
- switch (type.getKind()) {
- default:
- return printDialectType(type);
-
- case Type::Kind::Opaque: {
- auto opaqueTy = type.cast<OpaqueType>();
- printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
- opaqueTy.getTypeData());
- return;
- }
- case StandardTypes::Index:
- os << "index";
- return;
- case StandardTypes::BF16:
- os << "bf16";
- return;
- case StandardTypes::F16:
- os << "f16";
- return;
- case StandardTypes::F32:
- os << "f32";
- return;
- case StandardTypes::F64:
- os << "f64";
- return;
-
- case StandardTypes::Integer: {
- auto integer = type.cast<IntegerType>();
- if (integer.isSigned())
- os << 's';
- else if (integer.isUnsigned())
- os << 'u';
- os << 'i' << integer.getWidth();
- return;
- }
- case Type::Kind::Function: {
- auto func = type.cast<FunctionType>();
- os << '(';
- interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
- os << ") -> ";
- auto results = func.getResults();
- if (results.size() == 1 && !results[0].isa<FunctionType>())
- os << results[0];
- else {
- os << '(';
- interleaveComma(results, [&](Type type) { printType(type); });
- os << ')';
- }
- return;
- }
- case StandardTypes::Vector: {
- auto v = type.cast<VectorType>();
- os << "vector<";
- for (auto dim : v.getShape())
- os << dim << 'x';
- os << v.getElementType() << '>';
- return;
- }
- case StandardTypes::RankedTensor: {
- auto v = type.cast<RankedTensorType>();
- os << "tensor<";
- for (auto dim : v.getShape()) {
- if (dim < 0)
- os << '?';
- else
- os << dim;
- os << 'x';
- }
- os << v.getElementType() << '>';
- return;
- }
- case StandardTypes::UnrankedTensor: {
- auto v = type.cast<UnrankedTensorType>();
- os << "tensor<*x";
- printType(v.getElementType());
- os << '>';
- return;
- }
- case StandardTypes::MemRef: {
- auto v = type.cast<MemRefType>();
- os << "memref<";
- for (auto dim : v.getShape()) {
- if (dim < 0)
- os << '?';
- else
- os << dim;
- os << 'x';
- }
- printType(v.getElementType());
- for (auto map : v.getAffineMaps()) {
- os << ", ";
- printAttribute(AffineMapAttr::get(map));
- }
- // Only print the memory space if it is the non-default one.
- if (v.getMemorySpace())
- os << ", " << v.getMemorySpace();
- os << '>';
- return;
- }
- case StandardTypes::UnrankedMemRef: {
- auto v = type.cast<UnrankedMemRefType>();
- os << "memref<*x";
- printType(v.getElementType());
- os << '>';
- return;
- }
- case StandardTypes::Complex:
- os << "complex<";
- printType(type.cast<ComplexType>().getElementType());
- os << '>';
- return;
- case StandardTypes::Tuple: {
- auto tuple = type.cast<TupleType>();
- os << "tuple<";
- interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
- os << '>';
- return;
- }
- case StandardTypes::None:
- os << "none";
- return;
- }
+ TypeSwitch<Type>(type)
+ .Case<OpaqueType>([&](OpaqueType opaqueTy) {
+ printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
+ opaqueTy.getTypeData());
+ })
+ .Case<IndexType>([&](Type) { os << "index"; })
+ .Case<BFloat16Type>([&](Type) { os << "bf16"; })
+ .Case<Float16Type>([&](Type) { os << "f16"; })
+ .Case<Float32Type>([&](Type) { os << "f32"; })
+ .Case<Float64Type>([&](Type) { os << "f64"; })
+ .Case<IntegerType>([&](IntegerType integerTy) {
+ if (integerTy.isSigned())
+ os << 's';
+ else if (integerTy.isUnsigned())
+ os << 'u';
+ os << 'i' << integerTy.getWidth();
+ })
+ .Case<FunctionType>([&](FunctionType funcTy) {
+ os << '(';
+ interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
+ os << ") -> ";
+ ArrayRef<Type> results = funcTy.getResults();
+ if (results.size() == 1 && !results[0].isa<FunctionType>()) {
+ os << results[0];
+ } else {
+ os << '(';
+ interleaveComma(results, [&](Type ty) { printType(ty); });
+ os << ')';
+ }
+ })
+ .Case<VectorType>([&](VectorType vectorTy) {
+ os << "vector<";
+ for (int64_t dim : vectorTy.getShape())
+ os << dim << 'x';
+ os << vectorTy.getElementType() << '>';
+ })
+ .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
+ os << "tensor<";
+ for (int64_t dim : tensorTy.getShape()) {
+ if (ShapedType::isDynamic(dim))
+ os << '?';
+ else
+ os << dim;
+ os << 'x';
+ }
+ os << tensorTy.getElementType() << '>';
+ })
+ .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
+ os << "tensor<*x";
+ printType(tensorTy.getElementType());
+ os << '>';
+ })
+ .Case<MemRefType>([&](MemRefType memrefTy) {
+ os << "memref<";
+ for (int64_t dim : memrefTy.getShape()) {
+ if (ShapedType::isDynamic(dim))
+ os << '?';
+ else
+ os << dim;
+ os << 'x';
+ }
+ printType(memrefTy.getElementType());
+ for (auto map : memrefTy.getAffineMaps()) {
+ os << ", ";
+ printAttribute(AffineMapAttr::get(map));
+ }
+ // Only print the memory space if it is the non-default one.
+ if (memrefTy.getMemorySpace())
+ os << ", " << memrefTy.getMemorySpace();
+ os << '>';
+ })
+ .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
+ os << "memref<*x";
+ printType(memrefTy.getElementType());
+ os << '>';
+ })
+ .Case<ComplexType>([&](ComplexType complexTy) {
+ os << "complex<";
+ printType(complexTy.getElementType());
+ os << '>';
+ })
+ .Case<TupleType>([&](TupleType tupleTy) {
+ os << "tuple<";
+ interleaveComma(tupleTy.getTypes(),
+ [&](Type type) { printType(type); });
+ os << '>';
+ })
+ .Case<NoneType>([&](Type) { os << "none"; })
+ .Default([&](Type type) { return printDialectType(type); });
}
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index dceb07213492..dba7872b2a6c 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -748,21 +748,13 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
for (unsigned i = 0, e = values.size(); i < e; ++i) {
assert(eltType == values[i].getType() &&
"expected attribute value to have element type");
-
- switch (eltType.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
+ if (eltType.isa<FloatType>())
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
- break;
- case StandardTypes::Integer:
- case StandardTypes::Index:
+ else if (eltType.isa<IntegerType>())
intVal = values[i].cast<IntegerAttr>().getValue();
- break;
- default:
+ else
llvm_unreachable("unexpected element type");
- }
+
assert(intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 68c5103bd89a..c45d03174db9 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -268,27 +268,19 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
}
Attribute Builder::getZeroAttr(Type type) {
- switch (type.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
+ if (type.isa<FloatType>())
return getFloatAttr(type, 0.0);
- case StandardTypes::Index:
+ if (type.isa<IndexType>())
return getIndexAttr(0);
- case StandardTypes::Integer:
+ if (auto integerType = type.dyn_cast<IntegerType>())
return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
- case StandardTypes::Vector:
- case StandardTypes::RankedTensor: {
+ if (type.isa<RankedTensorType, VectorType>()) {
auto vtType = type.cast<ShapedType>();
auto element = getZeroAttr(vtType.getElementType());
if (!element)
return {};
return DenseElementsAttr::get(vtType, element);
}
- default:
- break;
- }
return {};
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index ff6c28bb7c4d..0d66070657aa 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -95,9 +95,10 @@ struct BuiltinDialect : public Dialect {
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();
- addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
- MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
- RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
+ addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
+ FunctionType, IndexType, IntegerType, MemRefType,
+ UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
+ TupleType, UnrankedTensorType, VectorType>();
// TODO: These operations should be moved to a different dialect when they
// have been fully decoupled from the core.
@@ -313,7 +314,10 @@ class MLIRContextImpl {
StorageUniquer typeUniquer;
/// Cached Type Instances.
- FloatType bf16Ty, f16Ty, f32Ty, f64Ty;
+ BFloat16Type bf16Ty;
+ Float16Type f16Ty;
+ Float32Type f32Ty;
+ Float64Type f64Ty;
IndexType indexTy;
IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
NoneType noneType;
@@ -359,10 +363,10 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
//// Types.
/// Floating-point Types.
- impl->bf16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::BF16);
- impl->f16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F16);
- impl->f32Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F32);
- impl->f64Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F64);
+ impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
+ impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
+ impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
+ impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
/// Index Type.
impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
/// Integer Types.
@@ -660,19 +664,17 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
-FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
- switch (kind) {
- case StandardTypes::BF16:
- return context->getImpl().bf16Ty;
- case StandardTypes::F16:
- return context->getImpl().f16Ty;
- case StandardTypes::F32:
- return context->getImpl().f32Ty;
- case StandardTypes::F64:
- return context->getImpl().f64Ty;
- default:
- llvm_unreachable("unexpected floating-point kind");
- }
+BFloat16Type BFloat16Type::get(MLIRContext *context) {
+ return context->getImpl().bf16Ty;
+}
+Float16Type Float16Type::get(MLIRContext *context) {
+ return context->getImpl().f16Ty;
+}
+Float32Type Float32Type::get(MLIRContext *context) {
+ return context->getImpl().f32Ty;
+}
+Float64Type Float64Type::get(MLIRContext *context) {
+ return context->getImpl().f64Ty;
}
/// Get an instance of the IndexType.
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index c2295bb14573..9205a7da9000 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -22,10 +22,10 @@ using namespace mlir::detail;
// Type
//===----------------------------------------------------------------------===//
-bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
-bool Type::isF16() { return getKind() == StandardTypes::F16; }
-bool Type::isF32() { return getKind() == StandardTypes::F32; }
-bool Type::isF64() { return getKind() == StandardTypes::F64; }
+bool Type::isBF16() { return isa<BFloat16Type>(); }
+bool Type::isF16() { return isa<Float16Type>(); }
+bool Type::isF32() { return isa<Float32Type>(); }
+bool Type::isF64() { return isa<Float64Type>(); }
bool Type::isIndex() { return isa<IndexType>(); }
@@ -90,6 +90,13 @@ bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
+unsigned Type::getIntOrFloatBitWidth() {
+ assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+ if (auto intType = dyn_cast<IntegerType>())
+ return intType.getWidth();
+ return cast<FloatType>().getWidth();
+}
+
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
@@ -142,39 +149,28 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
- switch (getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
+ if (isa<Float16Type, BFloat16Type>())
return 16;
- case StandardTypes::F32:
+ if (isa<Float32Type>())
return 32;
- case StandardTypes::F64:
+ if (isa<Float64Type>())
return 64;
- default:
- llvm_unreachable("unexpected type");
- }
+ llvm_unreachable("unexpected float type");
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (isBF16())
+ if (isa<BFloat16Type>())
return APFloat::BFloat();
- if (isF16())
+ if (isa<Float16Type>())
return APFloat::IEEEhalf();
- if (isF32())
+ if (isa<Float32Type>())
return APFloat::IEEEsingle();
- if (isF64())
+ if (isa<Float64Type>())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
-unsigned Type::getIntOrFloatBitWidth() {
- assert(isIntOrFloat() && "only integers and floats have a bitwidth");
- if (auto intType = dyn_cast<IntegerType>())
- return intType.getWidth();
- return cast<FloatType>().getWidth();
-}
-
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list