[Mlir-commits] [mlir] 6527712 - [mlir][Type] Remove the remaining usages of Type::getKind in preparation for its removal

River Riddle llvmlistbot at llvm.org
Wed Aug 12 19:40:41 PDT 2020


Author: River Riddle
Date: 2020-08-12T19:33:58-07:00
New Revision: 65277126bf90401436e018fcce0fc636d34ea771

URL: https://github.com/llvm/llvm-project/commit/65277126bf90401436e018fcce0fc636d34ea771
DIFF: https://github.com/llvm/llvm-project/commit/65277126bf90401436e018fcce0fc636d34ea771.diff

LOG: [mlir][Type] Remove the remaining usages of Type::getKind in preparation for its removal

This revision removes all of the lingering usages of Type::getKind. A consequence of this is that FloatType is now split into 4 derived types that represent each of the possible float types (BFloat16Type, Float16Type, Float32Type, and Float64Type). Other than this split, this revision is NFC (no functional change).

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D85566

Added: 
    

Modified: 
    mlir/include/mlir/IR/StandardTypes.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/StandardTypes.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index a4a4566e3eb8..6ceddec53377 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -180,25 +180,18 @@ class IntegerType
 // FloatType
 //===----------------------------------------------------------------------===//
 
-class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
+class FloatType : public Type {
 public:
-  using Base::Base;
-
-  static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
+  using Type::Type;
 
   // Convenience factories.
-  static FloatType getBF16(MLIRContext *ctx) {
-    return get(StandardTypes::BF16, ctx);
-  }
-  static FloatType getF16(MLIRContext *ctx) {
-    return get(StandardTypes::F16, ctx);
-  }
-  static FloatType getF32(MLIRContext *ctx) {
-    return get(StandardTypes::F32, ctx);
-  }
-  static FloatType getF64(MLIRContext *ctx) {
-    return get(StandardTypes::F64, ctx);
-  }
+  static FloatType getBF16(MLIRContext *ctx);
+  static FloatType getF16(MLIRContext *ctx);
+  static FloatType getF32(MLIRContext *ctx);
+  static FloatType getF64(MLIRContext *ctx);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(Type type);
 
   /// Return the bitwidth of this float type.
   unsigned getWidth();
@@ -207,6 +200,67 @@ class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
   const llvm::fltSemantics &getFloatSemantics();
 };
 
+//===----------------------------------------------------------------------===//
+// BFloat16Type
+
+class BFloat16Type
+    : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the bfloat16 type.
+  static BFloat16Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getBF16(MLIRContext *ctx) {
+  return BFloat16Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float16Type
+
+class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the float16 type.
+  static Float16Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF16(MLIRContext *ctx) {
+  return Float16Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float32Type
+
+class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the float32 type.
+  static Float32Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF32(MLIRContext *ctx) {
+  return Float32Type::get(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Float64Type
+
+class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the float64 type.
+  static Float64Type get(MLIRContext *context);
+};
+
+inline FloatType FloatType::getF64(MLIRContext *ctx) {
+  return Float64Type::get(ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // NoneType
 //===----------------------------------------------------------------------===//
@@ -623,6 +677,10 @@ inline bool BaseMemRefType::classof(Type type) {
   return type.isa<MemRefType, UnrankedMemRefType>();
 }
 
+inline bool FloatType::classof(Type type) {
+  return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
+}
+
 inline bool ShapedType::classof(Type type) {
   return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
                   UnrankedMemRefType, MemRefType>();

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 08d1c32d13c5..117cc9721af7 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -210,19 +210,15 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
 }
 
 Type LLVMTypeConverter::convertFloatType(FloatType type) {
-  switch (type.getKind()) {
-  case mlir::StandardTypes::F32:
+  if (type.isa<Float32Type>())
     return LLVM::LLVMType::getFloatTy(&getContext());
-  case mlir::StandardTypes::F64:
+  if (type.isa<Float64Type>())
     return LLVM::LLVMType::getDoubleTy(&getContext());
-  case mlir::StandardTypes::F16:
+  if (type.isa<Float16Type>())
     return LLVM::LLVMType::getHalfTy(&getContext());
-  case mlir::StandardTypes::BF16: {
+  if (type.isa<BFloat16Type>())
     return LLVM::LLVMType::getBFloatTy(&getContext());
-  }
-  default:
-    llvm_unreachable("non-float type in convertFloatType");
-  }
+  llvm_unreachable("non-float type in convertFloatType");
 }
 
 // Convert a `ComplexType` to an LLVM type. The result is a complex number

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 59ae48ebf3e7..aa611d76a67a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -23,46 +24,28 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
 
 /// Returns the keyword to use for the given type.
 static StringRef getTypeKeyword(LLVMType type) {
-  switch (type.getKind()) {
-  case LLVMType::VoidType:
-    return "void";
-  case LLVMType::HalfType:
-    return "half";
-  case LLVMType::BFloatType:
-    return "bfloat";
-  case LLVMType::FloatType:
-    return "float";
-  case LLVMType::DoubleType:
-    return "double";
-  case LLVMType::FP128Type:
-    return "fp128";
-  case LLVMType::X86FP80Type:
-    return "x86_fp80";
-  case LLVMType::PPCFP128Type:
-    return "ppc_fp128";
-  case LLVMType::X86MMXType:
-    return "x86_mmx";
-  case LLVMType::TokenType:
-    return "token";
-  case LLVMType::LabelType:
-    return "label";
-  case LLVMType::MetadataType:
-    return "metadata";
-  case LLVMType::FunctionType:
-    return "func";
-  case LLVMType::IntegerType:
-    return "i";
-  case LLVMType::PointerType:
-    return "ptr";
-  case LLVMType::FixedVectorType:
-  case LLVMType::ScalableVectorType:
-    return "vec";
-  case LLVMType::ArrayType:
-    return "array";
-  case LLVMType::StructType:
-    return "struct";
-  }
-  llvm_unreachable("unhandled type kind");
+  return TypeSwitch<Type, StringRef>(type)
+      .Case<LLVMVoidType>([&](Type) { return "void"; })
+      .Case<LLVMHalfType>([&](Type) { return "half"; })
+      .Case<LLVMBFloatType>([&](Type) { return "bfloat"; })
+      .Case<LLVMFloatType>([&](Type) { return "float"; })
+      .Case<LLVMDoubleType>([&](Type) { return "double"; })
+      .Case<LLVMFP128Type>([&](Type) { return "fp128"; })
+      .Case<LLVMX86FP80Type>([&](Type) { return "x86_fp80"; })
+      .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
+      .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
+      .Case<LLVMTokenType>([&](Type) { return "token"; })
+      .Case<LLVMLabelType>([&](Type) { return "label"; })
+      .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
+      .Case<LLVMFunctionType>([&](Type) { return "func"; })
+      .Case<LLVMIntegerType>([&](Type) { return "i"; })
+      .Case<LLVMPointerType>([&](Type) { return "ptr"; })
+      .Case<LLVMVectorType>([&](Type) { return "vec"; })
+      .Case<LLVMArrayType>([&](Type) { return "array"; })
+      .Case<LLVMStructType>([&](Type) { return "struct"; })
+      .Default([](Type) -> StringRef {
+        llvm_unreachable("unexpected 'llvm' type kind");
+      });
 }
 
 /// Prints the body of a structure type. Uses `stack` to avoid printing
@@ -153,14 +136,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
     return;
   }
 
-  unsigned kind = type.getKind();
   os << getTypeKeyword(type);
 
-  // Trivial types only consist of their keyword.
-  if (LLVMType::FIRST_TRIVIAL_TYPE <= kind &&
-      kind <= LLVMType::LAST_TRIVIAL_TYPE)
-    return;
-
   if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
     os << intType.getBitWidth();
     return;
@@ -190,7 +167,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
   if (auto structType = type.dyn_cast<LLVMStructType>())
     return printStructType(os, structType, stack);
 
-  printFunctionType(os, type.cast<LLVMFunctionType>(), stack);
+  if (auto funcType = type.dyn_cast<LLVMFunctionType>())
+    return printFunctionType(os, funcType, stack);
 }
 
 void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {

diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 6144edef2966..b52ea812306e 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
@@ -188,108 +189,70 @@ Type CompositeType::getElementType(unsigned index) const {
 }
 
 unsigned CompositeType::getNumElements() const {
-  switch (getKind()) {
-  case spirv::TypeKind::Array:
-    return cast<ArrayType>().getNumElements();
-  case spirv::TypeKind::CooperativeMatrix:
+  if (auto arrayType = dyn_cast<ArrayType>())
+    return arrayType.getNumElements();
+  if (auto matrixType = dyn_cast<MatrixType>())
+    return matrixType.getNumColumns();
+  if (auto structType = dyn_cast<StructType>())
+    return structType.getNumElements();
+  if (auto vectorType = dyn_cast<VectorType>())
+    return vectorType.getNumElements();
+  if (isa<CooperativeMatrixNVType>()) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::CooperativeMatrix type");
-  case spirv::TypeKind::Matrix:
-    return cast<MatrixType>().getNumColumns();
-  case spirv::TypeKind::RuntimeArray:
+  }
+  if (isa<RuntimeArrayType>()) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
-  case spirv::TypeKind::Struct:
-    return cast<StructType>().getNumElements();
-  case StandardTypes::Vector:
-    return cast<VectorType>().getNumElements();
-  default:
-    llvm_unreachable("invalid composite type");
   }
+  llvm_unreachable("invalid composite type");
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  switch (getKind()) {
-  case TypeKind::CooperativeMatrix:
-  case TypeKind::RuntimeArray:
-    return false;
-  default:
-    return true;
-  }
+  return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
 }
 
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     Optional<StorageClass> storage) {
-  switch (getKind()) {
-  case spirv::TypeKind::Array:
-    cast<ArrayType>().getExtensions(extensions, storage);
-    break;
-  case spirv::TypeKind::CooperativeMatrix:
-    cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
-    break;
-  case spirv::TypeKind::Matrix:
-    cast<MatrixType>().getExtensions(extensions, storage);
-    break;
-  case spirv::TypeKind::RuntimeArray:
-    cast<RuntimeArrayType>().getExtensions(extensions, storage);
-    break;
-  case spirv::TypeKind::Struct:
-    cast<StructType>().getExtensions(extensions, storage);
-    break;
-  case StandardTypes::Vector:
-    cast<VectorType>().getElementType().cast<ScalarType>().getExtensions(
-        extensions, storage);
-    break;
-  default:
-    llvm_unreachable("invalid composite type");
-  }
+  TypeSwitch<Type>(*this)
+      .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
+            StructType>(
+          [&](auto type) { type.getExtensions(extensions, storage); })
+      .Case<VectorType>([&](VectorType type) {
+        return type.getElementType().cast<ScalarType>().getExtensions(
+            extensions, storage);
+      })
+      .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
 void CompositeType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     Optional<StorageClass> storage) {
-  switch (getKind()) {
-  case spirv::TypeKind::Array:
-    cast<ArrayType>().getCapabilities(capabilities, storage);
-    break;
-  case spirv::TypeKind::CooperativeMatrix:
-    cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
-    break;
-  case spirv::TypeKind::Matrix:
-    cast<MatrixType>().getCapabilities(capabilities, storage);
-    break;
-  case spirv::TypeKind::RuntimeArray:
-    cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
-    break;
-  case spirv::TypeKind::Struct:
-    cast<StructType>().getCapabilities(capabilities, storage);
-    break;
-  case StandardTypes::Vector:
-    cast<VectorType>().getElementType().cast<ScalarType>().getCapabilities(
-        capabilities, storage);
-    break;
-  default:
-    llvm_unreachable("invalid composite type");
-  }
+  TypeSwitch<Type>(*this)
+      .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
+            StructType>(
+          [&](auto type) { type.getCapabilities(capabilities, storage); })
+      .Case<VectorType>([&](VectorType type) {
+        return type.getElementType().cast<ScalarType>().getCapabilities(
+            capabilities, storage);
+      })
+      .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
 Optional<int64_t> CompositeType::getSizeInBytes() {
-  switch (getKind()) {
-  case spirv::TypeKind::Array:
-    return cast<ArrayType>().getSizeInBytes();
-  case spirv::TypeKind::Struct:
-    return cast<StructType>().getSizeInBytes();
-  case StandardTypes::Vector: {
-    auto elementSize =
-        cast<VectorType>().getElementType().cast<ScalarType>().getSizeInBytes();
+  if (auto arrayType = dyn_cast<ArrayType>())
+    return arrayType.getSizeInBytes();
+  if (auto structType = dyn_cast<StructType>())
+    return structType.getSizeInBytes();
+  if (auto vectorType = dyn_cast<VectorType>()) {
+    Optional<int64_t> elementSize =
+        vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
     if (!elementSize)
       return llvm::None;
-    return *elementSize * cast<VectorType>().getNumElements();
-  }
-  default:
-    return llvm::None;
+    return *elementSize * vectorType.getNumElements();
   }
+  return llvm::None;
 }
 
 //===----------------------------------------------------------------------===//
@@ -741,8 +704,7 @@ Optional<int64_t> ScalarType::getSizeInBytes() {
 
 bool SPIRVType::classof(Type type) {
   // Allow SPIR-V dialect types
-  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
-      type.getKind() <= TypeKind::LAST_SPIRV_TYPE)
+  if (llvm::isa<SPIRVDialect>(type.getDialect()))
     return true;
   if (type.isa<ScalarType>())
     return true;

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 7cf02149fa25..511ec9bf2b4e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -141,28 +142,14 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
 
 /// Print a type registered to this dialect.
 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
-  switch (type.getKind()) {
-  case ShapeTypes::Component:
-    os << "component";
-    return;
-  case ShapeTypes::Element:
-    os << "element";
-    return;
-  case ShapeTypes::Size:
-    os << "size";
-    return;
-  case ShapeTypes::Shape:
-    os << "shape";
-    return;
-  case ShapeTypes::ValueShape:
-    os << "value_shape";
-    return;
-  case ShapeTypes::Witness:
-    os << "witness";
-    return;
-  default:
-    llvm_unreachable("unexpected 'shape' type kind");
-  }
+  TypeSwitch<Type>(type)
+      .Case<ComponentType>([&](Type) { os << "component"; })
+      .Case<ElementType>([&](Type) { os << "element"; })
+      .Case<ShapeType>([&](Type) { os << "shape"; })
+      .Case<SizeType>([&](Type) { os << "size"; })
+      .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
+      .Case<WitnessType>([&](Type) { os << "witness"; })
+      .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4470f5b4b826..c8b4a864fb63 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1576,128 +1576,95 @@ void ModulePrinter::printType(Type type) {
     }
   }
 
-  switch (type.getKind()) {
-  default:
-    return printDialectType(type);
-
-  case Type::Kind::Opaque: {
-    auto opaqueTy = type.cast<OpaqueType>();
-    printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
-                       opaqueTy.getTypeData());
-    return;
-  }
-  case StandardTypes::Index:
-    os << "index";
-    return;
-  case StandardTypes::BF16:
-    os << "bf16";
-    return;
-  case StandardTypes::F16:
-    os << "f16";
-    return;
-  case StandardTypes::F32:
-    os << "f32";
-    return;
-  case StandardTypes::F64:
-    os << "f64";
-    return;
-
-  case StandardTypes::Integer: {
-    auto integer = type.cast<IntegerType>();
-    if (integer.isSigned())
-      os << 's';
-    else if (integer.isUnsigned())
-      os << 'u';
-    os << 'i' << integer.getWidth();
-    return;
-  }
-  case Type::Kind::Function: {
-    auto func = type.cast<FunctionType>();
-    os << '(';
-    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
-    os << ") -> ";
-    auto results = func.getResults();
-    if (results.size() == 1 && !results[0].isa<FunctionType>())
-      os << results[0];
-    else {
-      os << '(';
-      interleaveComma(results, [&](Type type) { printType(type); });
-      os << ')';
-    }
-    return;
-  }
-  case StandardTypes::Vector: {
-    auto v = type.cast<VectorType>();
-    os << "vector<";
-    for (auto dim : v.getShape())
-      os << dim << 'x';
-    os << v.getElementType() << '>';
-    return;
-  }
-  case StandardTypes::RankedTensor: {
-    auto v = type.cast<RankedTensorType>();
-    os << "tensor<";
-    for (auto dim : v.getShape()) {
-      if (dim < 0)
-        os << '?';
-      else
-        os << dim;
-      os << 'x';
-    }
-    os << v.getElementType() << '>';
-    return;
-  }
-  case StandardTypes::UnrankedTensor: {
-    auto v = type.cast<UnrankedTensorType>();
-    os << "tensor<*x";
-    printType(v.getElementType());
-    os << '>';
-    return;
-  }
-  case StandardTypes::MemRef: {
-    auto v = type.cast<MemRefType>();
-    os << "memref<";
-    for (auto dim : v.getShape()) {
-      if (dim < 0)
-        os << '?';
-      else
-        os << dim;
-      os << 'x';
-    }
-    printType(v.getElementType());
-    for (auto map : v.getAffineMaps()) {
-      os << ", ";
-      printAttribute(AffineMapAttr::get(map));
-    }
-    // Only print the memory space if it is the non-default one.
-    if (v.getMemorySpace())
-      os << ", " << v.getMemorySpace();
-    os << '>';
-    return;
-  }
-  case StandardTypes::UnrankedMemRef: {
-    auto v = type.cast<UnrankedMemRefType>();
-    os << "memref<*x";
-    printType(v.getElementType());
-    os << '>';
-    return;
-  }
-  case StandardTypes::Complex:
-    os << "complex<";
-    printType(type.cast<ComplexType>().getElementType());
-    os << '>';
-    return;
-  case StandardTypes::Tuple: {
-    auto tuple = type.cast<TupleType>();
-    os << "tuple<";
-    interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
-    os << '>';
-    return;
-  }
-  case StandardTypes::None:
-    os << "none";
-    return;
-  }
+  TypeSwitch<Type>(type)
+      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
+        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
+                           opaqueTy.getTypeData());
+      })
+      .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
+      .Case<Float16Type>([&](Type) { os << "f16"; })
+      .Case<Float32Type>([&](Type) { os << "f32"; })
+      .Case<Float64Type>([&](Type) { os << "f64"; })
+      .Case<IntegerType>([&](IntegerType integerTy) {
+        if (integerTy.isSigned())
+          os << 's';
+        else if (integerTy.isUnsigned())
+          os << 'u';
+        os << 'i' << integerTy.getWidth();
+      })
+      .Case<FunctionType>([&](FunctionType funcTy) {
+        os << '(';
+        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
+        os << ") -> ";
+        ArrayRef<Type> results = funcTy.getResults();
+        if (results.size() == 1 && !results[0].isa<FunctionType>()) {
+          os << results[0];
+        } else {
+          os << '(';
+          interleaveComma(results, [&](Type ty) { printType(ty); });
+          os << ')';
+        }
+      })
+      .Case<VectorType>([&](VectorType vectorTy) {
+        os << "vector<";
+        for (int64_t dim : vectorTy.getShape())
+          os << dim << 'x';
+        os << vectorTy.getElementType() << '>';
+      })
+      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
+        os << "tensor<";
+        for (int64_t dim : tensorTy.getShape()) {
+          if (ShapedType::isDynamic(dim))
+            os << '?';
+          else
+            os << dim;
+          os << 'x';
+        }
+        os << tensorTy.getElementType() << '>';
+      })
+      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
+        os << "tensor<*x";
+        printType(tensorTy.getElementType());
+        os << '>';
+      })
+      .Case<MemRefType>([&](MemRefType memrefTy) {
+        os << "memref<";
+        for (int64_t dim : memrefTy.getShape()) {
+          if (ShapedType::isDynamic(dim))
+            os << '?';
+          else
+            os << dim;
+          os << 'x';
+        }
+        printType(memrefTy.getElementType());
+        for (auto map : memrefTy.getAffineMaps()) {
+          os << ", ";
+          printAttribute(AffineMapAttr::get(map));
+        }
+        // Only print the memory space if it is the non-default one.
+        if (memrefTy.getMemorySpace())
+          os << ", " << memrefTy.getMemorySpace();
+        os << '>';
+      })
+      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
+        os << "memref<*x";
+        printType(memrefTy.getElementType());
+        os << '>';
+      })
+      .Case<ComplexType>([&](ComplexType complexTy) {
+        os << "complex<";
+        printType(complexTy.getElementType());
+        os << '>';
+      })
+      .Case<TupleType>([&](TupleType tupleTy) {
+        os << "tuple<";
+        interleaveComma(tupleTy.getTypes(),
+                        [&](Type type) { printType(type); });
+        os << '>';
+      })
+      .Case<NoneType>([&](Type) { os << "none"; })
+      .Default([&](Type type) { return printDialectType(type); });
 }
 
 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index dceb07213492..dba7872b2a6c 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -748,21 +748,13 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   for (unsigned i = 0, e = values.size(); i < e; ++i) {
     assert(eltType == values[i].getType() &&
            "expected attribute value to have element type");
-
-    switch (eltType.getKind()) {
-    case StandardTypes::BF16:
-    case StandardTypes::F16:
-    case StandardTypes::F32:
-    case StandardTypes::F64:
+    if (eltType.isa<FloatType>())
       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
-      break;
-    case StandardTypes::Integer:
-    case StandardTypes::Index:
+    else if (eltType.isa<IntegerType>())
       intVal = values[i].cast<IntegerAttr>().getValue();
-      break;
-    default:
+    else
       llvm_unreachable("unexpected element type");
-    }
+
     assert(intVal.getBitWidth() == bitWidth &&
            "expected value to have same bitwidth as element type");
     writeBits(data.data(), i * storageBitWidth, intVal);

diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 68c5103bd89a..c45d03174db9 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -268,27 +268,19 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 Attribute Builder::getZeroAttr(Type type) {
-  switch (type.getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
-  case StandardTypes::F32:
-  case StandardTypes::F64:
+  if (type.isa<FloatType>())
     return getFloatAttr(type, 0.0);
-  case StandardTypes::Index:
+  if (type.isa<IndexType>())
     return getIndexAttr(0);
-  case StandardTypes::Integer:
+  if (auto integerType = type.dyn_cast<IntegerType>())
     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
-  case StandardTypes::Vector:
-  case StandardTypes::RankedTensor: {
+  if (type.isa<RankedTensorType, VectorType>()) {
     auto vtType = type.cast<ShapedType>();
     auto element = getZeroAttr(vtType.getElementType());
     if (!element)
       return {};
     return DenseElementsAttr::get(vtType, element);
   }
-  default:
-    break;
-  }
   return {};
 }
 

diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index ff6c28bb7c4d..0d66070657aa 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -95,9 +95,10 @@ struct BuiltinDialect : public Dialect {
     addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
                   UnknownLoc>();
 
-    addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
-             MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
-             RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
+    addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
+             FunctionType, IndexType, IntegerType, MemRefType,
+             UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
+             TupleType, UnrankedTensorType, VectorType>();
 
     // TODO: These operations should be moved to a different dialect when they
     // have been fully decoupled from the core.
@@ -313,7 +314,10 @@ class MLIRContextImpl {
   StorageUniquer typeUniquer;
 
   /// Cached Type Instances.
-  FloatType bf16Ty, f16Ty, f32Ty, f64Ty;
+  BFloat16Type bf16Ty;
+  Float16Type f16Ty;
+  Float32Type f32Ty;
+  Float64Type f64Ty;
   IndexType indexTy;
   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
   NoneType noneType;
@@ -359,10 +363,10 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
 
   //// Types.
   /// Floating-point Types.
-  impl->bf16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::BF16);
-  impl->f16Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F16);
-  impl->f32Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F32);
-  impl->f64Ty = TypeUniquer::get<FloatType>(this, StandardTypes::F64);
+  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
+  impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
+  impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
+  impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
   /// Index Type.
   impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
   /// Integer Types.
@@ -660,19 +664,17 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
-FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
-  switch (kind) {
-  case StandardTypes::BF16:
-    return context->getImpl().bf16Ty;
-  case StandardTypes::F16:
-    return context->getImpl().f16Ty;
-  case StandardTypes::F32:
-    return context->getImpl().f32Ty;
-  case StandardTypes::F64:
-    return context->getImpl().f64Ty;
-  default:
-    llvm_unreachable("unexpected floating-point kind");
-  }
+BFloat16Type BFloat16Type::get(MLIRContext *context) {
+  return context->getImpl().bf16Ty;
+}
+Float16Type Float16Type::get(MLIRContext *context) {
+  return context->getImpl().f16Ty;
+}
+Float32Type Float32Type::get(MLIRContext *context) {
+  return context->getImpl().f32Ty;
+}
+Float64Type Float64Type::get(MLIRContext *context) {
+  return context->getImpl().f64Ty;
 }
 
 /// Get an instance of the IndexType.

diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index c2295bb14573..9205a7da9000 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -22,10 +22,10 @@ using namespace mlir::detail;
 // Type
 //===----------------------------------------------------------------------===//
 
-bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
-bool Type::isF16() { return getKind() == StandardTypes::F16; }
-bool Type::isF32() { return getKind() == StandardTypes::F32; }
-bool Type::isF64() { return getKind() == StandardTypes::F64; }
+bool Type::isBF16() { return isa<BFloat16Type>(); }
+bool Type::isF16() { return isa<Float16Type>(); }
+bool Type::isF32() { return isa<Float32Type>(); }
+bool Type::isF64() { return isa<Float64Type>(); }
 
 bool Type::isIndex() { return isa<IndexType>(); }
 
@@ -90,6 +90,13 @@ bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
 
 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
 
+unsigned Type::getIntOrFloatBitWidth() {
+  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+  if (auto intType = dyn_cast<IntegerType>())
+    return intType.getWidth();
+  return cast<FloatType>().getWidth();
+}
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//
@@ -142,39 +149,28 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  switch (getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
+  if (isa<Float16Type, BFloat16Type>())
     return 16;
-  case StandardTypes::F32:
+  if (isa<Float32Type>())
     return 32;
-  case StandardTypes::F64:
+  if (isa<Float64Type>())
     return 64;
-  default:
-    llvm_unreachable("unexpected type");
-  }
+  llvm_unreachable("unexpected float type");
 }
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (isBF16())
+  if (isa<BFloat16Type>())
     return APFloat::BFloat();
-  if (isF16())
+  if (isa<Float16Type>())
     return APFloat::IEEEhalf();
-  if (isF32())
+  if (isa<Float32Type>())
     return APFloat::IEEEsingle();
-  if (isF64())
+  if (isa<Float64Type>())
     return APFloat::IEEEdouble();
   llvm_unreachable("non-floating point type used");
 }
 
-unsigned Type::getIntOrFloatBitWidth() {
-  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
-  if (auto intType = dyn_cast<IntegerType>())
-    return intType.getWidth();
-  return cast<FloatType>().getWidth();
-}
-
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list