[Mlir-commits] [mlir] 676bfb2 - [mlir] Refactor ShapedType into an interface
River Riddle
llvmlistbot at llvm.org
Wed Jan 12 14:12:59 PST 2022
Author: River Riddle
Date: 2022-01-12T14:12:09-08:00
New Revision: 676bfb2a226e705d801016bb433b68a1e09a1e10
URL: https://github.com/llvm/llvm-project/commit/676bfb2a226e705d801016bb433b68a1e09a1e10
DIFF: https://github.com/llvm/llvm-project/commit/676bfb2a226e705d801016bb433b68a1e09a1e10.diff
LOG: [mlir] Refactor ShapedType into an interface
ShapedType was created in a time before interfaces, and is one of the earliest
type base classes in the ecosystem. This commit refactors ShapedType into
an interface, which is what it would have been if interfaces had existed at that
time. The API of ShapedType and its derived classes is essentially untouched
by this refactor, with the exception of the API surrounding kDynamicIndex
(which requires a sole home).
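A minimal standalone sketch of what a "sole home" means for callers: the dynamic sentinels and their predicates live on one owner, and call sites query that owner. The updated call sites in this diff do the same thing with ShapedType:: in place of MemRefType:: or RankedTensorType::. `DynamicDims` is a hypothetical name; only the sentinel values (-1 and the minimum int64_t) come from the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical single owner of the dynamic-size API. The sentinel values
// mirror the ones declared on ShapedType in this commit.
struct DynamicDims {
  static constexpr int64_t kDynamicSize = -1;
  static constexpr int64_t kDynamicStrideOrOffset =
      std::numeric_limits<int64_t>::min();

  // True if a dimension size is the dynamic-size sentinel.
  static constexpr bool isDynamic(int64_t dSize) {
    return dSize == kDynamicSize;
  }
  // True if a stride or offset is the dynamic stride/offset sentinel.
  static constexpr bool isDynamicStrideOrOffset(int64_t v) {
    return v == kDynamicStrideOrOffset;
  }
};

// Every caller goes through the same owner instead of a per-type copy.
static_assert(DynamicDims::isDynamic(-1), "-1 marks a dynamic size");
static_assert(!DynamicDims::isDynamic(0), "0 is a valid static size");
```

The two predicates are deliberately distinct in this model, as in the diff: -1 is not treated as a dynamic stride or offset.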
For now, the API of ShapedType and its name have been kept as consistent with
the current state of the world as possible (to help with potential migration churn,
among other reasons). Moving forward though, we should look into potentially
restructuring its API and possibly its name as well (it should really have "Interface"
at the end like other interfaces at the very least).
One other potentially interesting note is that I've attached the ShapedType::Trait
to TensorType/BaseMemRefType to act as mixins for the ShapedType API. This
is kind of weird, but allows for sharing the same API (i.e. preventing API loss from
the transition from base class -> Interface). This inheritance doesn't affect any
of the derived classes; it exists only as an API mixin.
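The mixin idea can be illustrated with the curiously recurring template pattern (CRTP): a template base derives utilities such as getRank() and hasStaticShape() from a few primitive accessors that the concrete class provides. This is a standalone sketch assuming that ShapedType::Trait behaves roughly like this; the TableGen-generated trait differs in detail, and `ShapedMixin` and `RankedModel` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical CRTP mixin: utilities are written once in terms of the
// concrete type's hasRank() and getShape().
template <typename ConcreteT>
struct ShapedMixin {
  int64_t getRank() const {
    assert(self().hasRank() && "cannot query rank of unranked type");
    return static_cast<int64_t>(self().getShape().size());
  }
  bool hasStaticShape() const {
    if (!self().hasRank())
      return false;
    const auto &shape = self().getShape();
    // -1 is the dynamic-size sentinel used throughout the diff.
    return std::none_of(shape.begin(), shape.end(),
                        [](int64_t d) { return d == -1; });
  }

private:
  const ConcreteT &self() const {
    return static_cast<const ConcreteT &>(*this);
  }
};

// Hypothetical concrete type standing in for TensorType/BaseMemRefType.
class RankedModel : public ShapedMixin<RankedModel> {
public:
  explicit RankedModel(std::vector<int64_t> shape) : shape_(std::move(shape)) {}
  bool hasRank() const { return true; }
  const std::vector<int64_t> &getShape() const { return shape_; }

private:
  std::vector<int64_t> shape_;
};
```

Because the mixin only calls back into the concrete class, attaching it adds API without changing the layout or behavior of the class it is attached to.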
Differential Revision: https://reviews.llvm.org/D116962
Added:
mlir/lib/IR/BuiltinTypeInterfaces.cpp
Modified:
mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/CMakeLists.txt
mlir/unittests/IR/ShapedTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index f8879e55bfd48..94031fd62bdfa 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -41,4 +41,151 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}
+//===----------------------------------------------------------------------===//
+// ShapedType
+//===----------------------------------------------------------------------===//
+
+def ShapedTypeInterface : TypeInterface<"ShapedType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a common API for interacting with multi-dimensional
+ container types. These types contain a shape and an element type.
+
+ A shape is a list of sizes corresponding to the dimensions of the container.
+ If the number of dimensions in the shape is unknown, the shape is "unranked".
+ If the number of dimensions is known, the shape is "ranked". The sizes of the
+ dimensions of the shape must be positive, or kDynamicSize (in which case the
+ size of the dimension is dynamic, or not statically known).
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns a clone of this type with the given shape and element
+ type. If a shape is not provided, the current shape of the type is used.
+ }],
+ "::mlir::ShapedType", "cloneWith", (ins
+ "::llvm::Optional<::llvm::ArrayRef<int64_t>>":$shape,
+ "::mlir::Type":$elementType
+ )>,
+
+ InterfaceMethod<[{
+ Returns the element type of this shaped type.
+ }],
+ "::mlir::Type", "getElementType">,
+
+ InterfaceMethod<[{
+ Returns if this type is ranked, i.e. it has a known number of dimensions.
+ }],
+ "bool", "hasRank">,
+
+ InterfaceMethod<[{
+ Returns the shape of this type if it is ranked, otherwise asserts.
+ }],
+ "::llvm::ArrayRef<int64_t>", "getShape">,
+ ];
+
+ let extraClassDeclaration = [{
+ // TODO: merge these two special values in a single one used everywhere.
+ // Unfortunately, uses of `-1` have crept deep into the codebase now and are
+ // hard to track.
+ static constexpr int64_t kDynamicSize = -1;
+ static constexpr int64_t kDynamicStrideOrOffset =
+ std::numeric_limits<int64_t>::min();
+
+ /// Whether the given dimension size indicates a dynamic dimension.
+ static constexpr bool isDynamic(int64_t dSize) {
+ return dSize == kDynamicSize;
+ }
+ static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
+ return dStrideOrOffset == kDynamicStrideOrOffset;
+ }
+
+ /// Return the number of elements present in the given shape.
+ static int64_t getNumElements(ArrayRef<int64_t> shape);
+
+ /// Returns the total amount of bits occupied by a value of this type. This
+ /// does not take into account any memory layout or widening constraints,
+ /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
+ /// practice it will likely be stored as in a 4xi64 vector register. Fails
+ /// with an assertion if the size cannot be computed statically, e.g. if the
+ /// type has a dynamic shape or if its elemental type does not have a known
+ /// bit width.
+ int64_t getSizeInBits() const;
+ }];
+
+ let extraSharedClassDeclaration = [{
+ /// Return a clone of this type with the given new shape and element type.
+ auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
+ return $_type.cloneWith(shape, elementType);
+ }
+ /// Return a clone of this type with the given new shape.
+ auto clone(::llvm::ArrayRef<int64_t> shape) {
+ return $_type.cloneWith(shape, $_type.getElementType());
+ }
+ /// Return a clone of this type with the given new element type.
+ auto clone(::mlir::Type elementType) {
+ return $_type.cloneWith(/*shape=*/llvm::None, elementType);
+ }
+
+ /// If an element type is an integer or a float, return its width. Otherwise,
+ /// abort.
+ unsigned getElementTypeBitWidth() const {
+ return $_type.getElementType().getIntOrFloatBitWidth();
+ }
+
+ /// If this is a ranked type, return the rank. Otherwise, abort.
+ int64_t getRank() const {
+ assert($_type.hasRank() && "cannot query rank of unranked shaped type");
+ return $_type.getShape().size();
+ }
+
+ /// If it has static shape, return the number of elements. Otherwise, abort.
+ int64_t getNumElements() const {
+ assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+ return ::mlir::ShapedType::getNumElements($_type.getShape());
+ }
+
+ /// Returns true if this dimension has a dynamic size (for ranked types);
+ /// aborts for unranked types.
+ bool isDynamicDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid index for shaped type");
+ return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
+ }
+
+ /// Returns if this type has a static shape, i.e. if the type is ranked and
+ /// all dimensions have known size (>= 0).
+ bool hasStaticShape() const {
+ return $_type.hasRank() &&
+ llvm::none_of($_type.getShape(), ::mlir::ShapedType::isDynamic);
+ }
+
+ /// Returns if this type has a static shape and the shape is equal to
+ /// `shape` return true.
+ bool hasStaticShape(::llvm::ArrayRef<int64_t> shape) const {
+ return hasStaticShape() && $_type.getShape() == shape;
+ }
+
+ /// If this is a ranked type, return the number of dimensions with dynamic
+ /// size. Otherwise, abort.
+ int64_t getNumDynamicDims() const {
+ return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
+ }
+
+ /// If this is ranked type, return the size of the specified dimension.
+ /// Otherwise, abort.
+ int64_t getDimSize(unsigned idx) const {
+ assert(idx < getRank() && "invalid index for shaped type");
+ return $_type.getShape()[idx];
+ }
+
+ /// Returns the position of the dynamic dimension relative to just the dynamic
+ /// dimensions, given its `index` within the shape.
+ unsigned getDynamicDimIndex(unsigned index) const {
+ assert(index < getRank() && "invalid index");
+ assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index");
+ return llvm::count_if($_type.getShape().take_front(index),
+ ::mlir::ShapedType::isDynamic);
+ }
+ }];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
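The getNumDynamicDims() and getDynamicDimIndex() helpers above reduce to counting dynamic entries in the shape. getDynamicDimIndex() maps a shape position to an index into the list of dynamic-size operands, which DimOp::fold in TensorOps.cpp (further down in this diff) does by hand. A standalone sketch over a plain vector; the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamicSize = -1;  // same sentinel value as in the diff

bool isDynamic(int64_t d) { return d == kDynamicSize; }

// Number of dynamic dimensions in a ranked shape.
int64_t numDynamicDims(const std::vector<int64_t> &shape) {
  return std::count_if(shape.begin(), shape.end(), isDynamic);
}

// Position of the dynamic dimension at `index` among the dynamic dimensions
// only, i.e. the number of dynamic dimensions strictly before it.
unsigned dynamicDimIndex(const std::vector<int64_t> &shape, unsigned index) {
  assert(index < shape.size() && "invalid index");
  assert(isDynamic(shape[index]) && "dimension is not dynamic");
  return static_cast<unsigned>(
      std::count_if(shape.begin(), shape.begin() + index, isDynamic));
}
```

For a shape such as 4x?x8x?, dimension 3 is the second dynamic dimension, so its dynamic index is 1.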
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 6f25f81e9342d..e087e6bf55d8c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -16,6 +16,12 @@ namespace llvm {
struct fltSemantics;
} // namespace llvm
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
+
namespace mlir {
class AffineExpr;
class AffineMap;
@@ -56,118 +62,67 @@ class FloatType : public Type {
};
//===----------------------------------------------------------------------===//
-// ShapedType
+// TensorType
//===----------------------------------------------------------------------===//
-/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
-/// and MemRef types because they share behavior and semantics around shape,
-/// rank, and fixed element type. Any type with these semantics should inherit
-/// from ShapedType.
-class ShapedType : public Type {
+/// Tensor types represent multi-dimensional arrays, and have two variants:
+/// RankedTensorType and UnrankedTensorType.
+/// Note: This class attaches the ShapedType trait to act as a mixin to
+/// provide many useful utility functions. This inheritance has no effect
+/// on derived tensor types.
+class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
using Type::Type;
- // TODO: merge these two special values in a single one used everywhere.
- // Unfortunately, uses of `-1` have crept deep into the codebase now and are
- // hard to track.
- static constexpr int64_t kDynamicSize = -1;
- static constexpr int64_t kDynamicStrideOrOffset =
- std::numeric_limits<int64_t>::min();
-
- /// Return clone of this type with new shape and element type.
- ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
- ShapedType clone(ArrayRef<int64_t> shape);
- ShapedType clone(Type elementType);
-
- /// Return the element type.
+ /// Returns the element type of this tensor type.
Type getElementType() const;
- /// If an element type is an integer or a float, return its width. Otherwise,
- /// abort.
- unsigned getElementTypeBitWidth() const;
-
- /// If it has static shape, return the number of elements. Otherwise, abort.
- int64_t getNumElements() const;
-
- /// If this is a ranked type, return the rank. Otherwise, abort.
- int64_t getRank() const;
-
- /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
- /// have a rank, while unranked tensors do not.
+ /// Returns if this type is ranked, i.e. it has a known number of dimensions.
bool hasRank() const;
- /// If this is a ranked type, return the shape. Otherwise, abort.
+ /// Returns the shape of this tensor type.
ArrayRef<int64_t> getShape() const;
- /// If this is unranked type or any dimension has unknown size (<0), it
- /// doesn't have static shape. If all dimensions have known size (>= 0), it
- /// has static shape.
- bool hasStaticShape() const;
-
- /// If this has a static shape and the shape is equal to `shape` return true.
- bool hasStaticShape(ArrayRef<int64_t> shape) const;
-
- /// If this is a ranked type, return the number of dimensions with dynamic
- /// size. Otherwise, abort.
- int64_t getNumDynamicDims() const;
-
- /// If this is ranked type, return the size of the specified dimension.
- /// Otherwise, abort.
- int64_t getDimSize(unsigned idx) const;
-
- /// Returns true if this dimension has a dynamic size (for ranked types);
- /// aborts for unranked types.
- bool isDynamicDim(unsigned idx) const;
-
- /// Returns the position of the dynamic dimension relative to just the dynamic
- /// dimensions, given its `index` within the shape.
- unsigned getDynamicDimIndex(unsigned index) const;
+ /// Clone this type with the given shape and element type. If the
+ /// provided shape is `None`, the current shape of the type is used.
+ TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
- /// Get the total amount of bits occupied by a value of this type. This does
- /// not take into account any memory layout or widening constraints, e.g. a
- /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
- /// it will likely be stored as in a 4xi64 vector register. Fail an assertion
- /// if the size cannot be computed statically, i.e. if the type has a dynamic
- /// shape or if its elemental type does not have a known bit width.
- int64_t getSizeInBits() const;
+ /// Return true if the specified element type is ok in a tensor.
+ static bool isValidElementType(Type type);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
- /// Whether the given dimension size indicates a dynamic dimension.
- static constexpr bool isDynamic(int64_t dSize) {
- return dSize == kDynamicSize;
- }
- static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
- return dStrideOrOffset == kDynamicStrideOrOffset;
- }
+ /// Allow implicit conversion to ShapedType.
+ operator ShapedType() const { return cast<ShapedType>(); }
};
//===----------------------------------------------------------------------===//
-// TensorType
+// BaseMemRefType
//===----------------------------------------------------------------------===//
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-class TensorType : public ShapedType {
+/// This class provides a shared interface for ranked and unranked memref types.
+/// Note: This class attaches the ShapedType trait to act as a mixin to
+/// provide many useful utility functions. This inheritance has no effect
+/// on derived memref types.
+class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
public:
- using ShapedType::ShapedType;
+ using Type::Type;
- /// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type);
+ /// Returns the element type of this memref type.
+ Type getElementType() const;
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-};
+ /// Returns if this type is ranked, i.e. it has a known number of dimensions.
+ bool hasRank() const;
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
+ /// Returns the shape of this memref type.
+ ArrayRef<int64_t> getShape() const;
-/// Base MemRef for Ranked and Unranked variants
-class BaseMemRefType : public ShapedType {
-public:
- using ShapedType::ShapedType;
+ /// Clone this type with the given shape and element type. If the
+ /// provided shape is `None`, the current shape of the type is used.
+ BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type);
@@ -181,6 +136,9 @@ class BaseMemRefType : public ShapedType {
/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
+
+ /// Allow implicit conversion to ShapedType.
+ operator ShapedType() const { return cast<ShapedType>(); }
};
} // namespace mlir
@@ -192,12 +150,6 @@ class BaseMemRefType : public ShapedType {
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
-//===----------------------------------------------------------------------===//
-// Tablegen Interface Declarations
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
-
namespace mlir {
//===----------------------------------------------------------------------===//
@@ -439,11 +391,6 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
return Float128Type::get(ctx);
}
-inline bool ShapedType::classof(Type type) {
- return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
- UnrankedMemRefType, MemRefType>();
-}
-
inline bool TensorType::classof(Type type) {
return type.isa<RankedTensorType, UnrankedTensorType>();
}
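The new cloneWith() contract takes an optional shape: when no shape is given, the current shape is kept, and the clone() convenience overloads forward to it. Below is a value-type sketch of that contract using std::optional in place of llvm::Optional. `ShapedModel` is hypothetical; real MLIR types are uniqued and owned by an MLIRContext, and element types are not strings.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of a shaped type: a shape plus an element type name.
struct ShapedModel {
  std::vector<int64_t> shape;
  std::string elementType;

  // If no shape is provided, the current shape is reused.
  ShapedModel cloneWith(std::optional<std::vector<int64_t>> newShape,
                        std::string newElementType) const {
    return ShapedModel{newShape ? *newShape : shape,
                       std::move(newElementType)};
  }

  // Convenience overloads in the style of the interface's clone() helpers.
  ShapedModel clone(std::vector<int64_t> newShape) const {
    return cloneWith(std::move(newShape), elementType);
  }
  ShapedModel clone(std::string newElementType) const {
    return cloneWith(std::nullopt, std::move(newElementType));
  }
};
```

Funnelling both overloads through one primitive means each concrete type (TensorType, BaseMemRefType, VectorType) only has to implement cloneWith().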
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 17e2de1ce10a8..d646688a08fea 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -266,7 +266,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
let description = [{
@@ -541,6 +541,16 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
+ using ShapedType::Trait<MemRefType>::clone;
+ using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
+ using ShapedType::Trait<MemRefType>::getRank;
+ using ShapedType::Trait<MemRefType>::getNumElements;
+ using ShapedType::Trait<MemRefType>::isDynamicDim;
+ using ShapedType::Trait<MemRefType>::hasStaticShape;
+ using ShapedType::Trait<MemRefType>::getNumDynamicDims;
+ using ShapedType::Trait<MemRefType>::getDimSize;
+ using ShapedType::Trait<MemRefType>::getDynamicDimIndex;
+
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
@@ -620,7 +630,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
@@ -702,6 +712,16 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
}]>
];
let extraClassDeclaration = [{
+ using ShapedType::Trait<RankedTensorType>::clone;
+ using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
+ using ShapedType::Trait<RankedTensorType>::getRank;
+ using ShapedType::Trait<RankedTensorType>::getNumElements;
+ using ShapedType::Trait<RankedTensorType>::isDynamicDim;
+ using ShapedType::Trait<RankedTensorType>::hasStaticShape;
+ using ShapedType::Trait<RankedTensorType>::getNumDynamicDims;
+ using ShapedType::Trait<RankedTensorType>::getDimSize;
+ using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
+
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
@@ -784,7 +804,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
@@ -831,6 +851,16 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
}]>
];
let extraClassDeclaration = [{
+ using ShapedType::Trait<UnrankedMemRefType>::clone;
+ using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
+ using ShapedType::Trait<UnrankedMemRefType>::getRank;
+ using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
+ using ShapedType::Trait<UnrankedMemRefType>::isDynamicDim;
+ using ShapedType::Trait<UnrankedMemRefType>::hasStaticShape;
+ using ShapedType::Trait<UnrankedMemRefType>::getNumDynamicDims;
+ using ShapedType::Trait<UnrankedMemRefType>::getDimSize;
+ using ShapedType::Trait<UnrankedMemRefType>::getDynamicDimIndex;
+
ArrayRef<int64_t> getShape() const { return llvm::None; }
/// [deprecated] Returns the memory space in old raw integer representation.
@@ -846,7 +876,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
@@ -874,6 +904,16 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
}]>
];
let extraClassDeclaration = [{
+ using ShapedType::Trait<UnrankedTensorType>::clone;
+ using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
+ using ShapedType::Trait<UnrankedTensorType>::getRank;
+ using ShapedType::Trait<UnrankedTensorType>::getNumElements;
+ using ShapedType::Trait<UnrankedTensorType>::isDynamicDim;
+ using ShapedType::Trait<UnrankedTensorType>::hasStaticShape;
+ using ShapedType::Trait<UnrankedTensorType>::getNumDynamicDims;
+ using ShapedType::Trait<UnrankedTensorType>::getDimSize;
+ using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
+
ArrayRef<int64_t> getShape() const { return llvm::None; }
}];
let skipDefaultBuilders = 1;
@@ -885,8 +925,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
//===----------------------------------------------------------------------===//
def Builtin_Vector : Builtin_Type<"Vector", [
- DeclareTypeInterfaceMethods<SubElementTypeInterface>
- ], "ShapedType"> {
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+ ], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
let description = [{
Syntax:
@@ -966,6 +1006,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [
/// element type of bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
VectorType scaleElementBitwidth(unsigned scale);
+
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Clone this vector type with the given shape and element type. If the
+ /// provided shape is `None`, the current shape of the type is used.
+ VectorType cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType);
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
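One plausible reading of the `using ShapedType::Trait<MemRefType>::...` lines (the commit does not state it): a type such as MemRefType gets the mixin twice, once through BaseMemRefType (Trait<BaseMemRefType>) and once through the attached ShapedTypeInterface (Trait<MemRefType>). An unqualified call like getRank() would then be found in two different base classes and be ambiguous; a using-declaration in the derived class resolves that. A standalone C++ illustration with hypothetical names:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mixin; each instantiation is a distinct base class type.
template <typename ConcreteT>
struct RankMixin {
  int64_t getRank() const {
    return static_cast<int64_t>(
        static_cast<const ConcreteT &>(*this).getShape().size());
  }
};

// Stand-in for BaseMemRefType: carries RankMixin<BaseModel>.
struct BaseModel : RankMixin<BaseModel> {
  std::vector<int64_t> shape;
  const std::vector<int64_t> &getShape() const { return shape; }
};

// Stand-in for MemRefType: inherits the base's mixin and also gets its own.
// Without the using-declaration, m.getRank() would be ambiguous because
// getRank is declared in two different base classes.
struct DerivedModel : BaseModel, RankMixin<DerivedModel> {
  using RankMixin<DerivedModel>::getRank;
};
```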
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 4f78461572f33..428d578d99ef8 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -51,10 +51,10 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
auto result = getStridesAndOffset(type, strides, offset);
(void)result;
assert(succeeded(result) && "unexpected failure in stride computation");
- assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
+ assert(!ShapedType::isDynamicStrideOrOffset(offset) &&
"expected static offset");
assert(!llvm::any_of(strides, [](int64_t stride) {
- return MemRefType::isDynamicStrideOrOffset(stride);
+ return ShapedType::isDynamicStrideOrOffset(stride);
}) && "expected static strides");
auto convertedType = typeConverter.convertType(type);
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 41e8cefd712ad..791a36156955a 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -79,14 +79,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value index;
if (offset != 0) // Skip if offset is zero.
- index = MemRefType::isDynamicStrideOrOffset(offset)
+ index = ShapedType::isDynamicStrideOrOffset(offset)
? memRefDescriptor.offset(rewriter, loc)
: createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
- Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
+ Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexConstant(rewriter, loc, strides[i]);
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
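The arithmetic behind getStridedElementPtr is the strided linearization offset + sum(indices[i] * strides[i]). For each term the lowering reads the value from the memref descriptor when the type marks it dynamic, and uses a constant otherwise. The zero-offset and unit-stride skips in the diff only avoid emitting trivial IR. A standalone sketch that computes the integer directly; `Descriptor` and `linearIndex` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

constexpr int64_t kDynamicStrideOrOffset = std::numeric_limits<int64_t>::min();

// Hypothetical runtime descriptor holding the actual offset and strides.
struct Descriptor {
  int64_t offset;
  std::vector<int64_t> strides;
};

// offset + sum(indices[i] * strides[i]); dynamic entries of the static
// layout are replaced by the descriptor's runtime values.
int64_t linearIndex(int64_t staticOffset,
                    const std::vector<int64_t> &staticStrides,
                    const Descriptor &desc,
                    const std::vector<int64_t> &indices) {
  assert(indices.size() == staticStrides.size());
  int64_t index = staticOffset == kDynamicStrideOrOffset ? desc.offset
                                                         : staticOffset;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    int64_t stride = staticStrides[i] == kDynamicStrideOrOffset
                         ? desc.strides[i]
                         : staticStrides[i];
    index += indices[i] * stride;
  }
  return index;
}
```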
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d8fa9654664e2..3bb802b1ec43d 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -106,7 +106,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
Operation *op) const {
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (type.isDynamic(type.getDimSize(i)))
+ if (ShapedType::isDynamic(type.getDimSize(i)))
continue;
sizeDivisor = sizeDivisor * type.getDimSize(i);
}
@@ -1467,7 +1467,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> strides, Value nextSize,
Value runningStride, unsigned idx) const {
assert(idx < strides.size());
- if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
+ if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
return createIndexConstant(rewriter, loc, strides[idx]);
if (nextSize)
return runningStride
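In the AlignedAllocOpLowering hunk above, sizeDivisor starts at the element size in bytes and multiplies in every statically known dimension, skipping dynamic ones. The result therefore divides the total allocation size whatever the runtime sizes are. The excerpt ends before the value is used; a likely use is checking whether the size is a multiple of an alignment, but that is an assumption. A standalone sketch with a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamicSize = -1;  // same sentinel value as in the diff

// Element size times the product of all static dimensions. Dynamic
// dimensions are skipped, so the result is a guaranteed divisor of the
// full allocation size.
uint64_t staticSizeDivisor(uint64_t eltSizeInBytes,
                           const std::vector<int64_t> &shape) {
  uint64_t divisor = eltSizeInBytes;
  for (int64_t dim : shape) {
    if (dim == kDynamicSize)
      continue;
    divisor *= static_cast<uint64_t>(dim);
  }
  return divisor;
}
```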
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2e7dc4112592b..42719eedffa73 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -342,22 +342,22 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
- if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+ if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}
// If cast is towards more static offset along any dimension, don't fold.
if (sourceOffset != resultOffset)
- if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
- !MemRefType::isDynamicStrideOrOffset(resultOffset))
+ if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
+ !ShapedType::isDynamicStrideOrOffset(resultOffset))
return false;
// If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
- if (MemRefType::isDynamicStrideOrOffset(ss) &&
- !MemRefType::isDynamicStrideOrOffset(st))
+ if (ShapedType::isDynamicStrideOrOffset(ss) &&
+ !ShapedType::isDynamicStrideOrOffset(st))
return false;
}
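The three loops in CastOp::canFoldIntoConsumerOp apply one rule to sizes, offset, and strides: folding is refused if, in any position, the cast goes from a dynamic source value to a static result value, since folding would drop that static information. The full function may check more than this excerpt shows. A standalone sketch of the shared rule, parameterized by the sentinel (kDynamicSize for sizes, kDynamicStrideOrOffset for strides and offsets); the name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns false if any position goes from dynamic in `src` to static in
// `dst`. (The diff's extra `ss != st` test is implied by these conditions.)
bool castIsNotMoreStatic(const std::vector<int64_t> &src,
                         const std::vector<int64_t> &dst,
                         int64_t dynamicSentinel) {
  assert(src.size() == dst.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    if (src[i] == dynamicSentinel && dst[i] != dynamicSentinel)
      return false;
  return true;
}
```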
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ca542a1c8f854..9afb67ef0ceef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -518,7 +518,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
// Find upper bound in current dimension.
unsigned p = perm(enc, d);
Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
- if (shape[p] == MemRefType::kDynamicSize)
+ if (ShapedType::isDynamic(shape[p]))
args.push_back(up);
assert(codegen.highs[tensor][idx] == nullptr);
codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 665021b4c70d6..216de423fc311 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -268,13 +268,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
fromElements.getResult().getType().cast<RankedTensorType>();
// The case where the type encodes the size of the dimension is handled
// above.
- assert(resultType.getShape()[index.getInt()] ==
- RankedTensorType::kDynamicSize);
+ assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
// Find the operand of the fromElements that corresponds to this index.
auto dynExtents = fromElements.dynamicExtents().begin();
for (auto dim : resultType.getShape().take_front(index.getInt()))
- if (dim == RankedTensorType::kDynamicSize)
+ if (ShapedType::isDynamic(dim))
dynExtents++;
return Value{*dynExtents};
@@ -523,13 +522,13 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
auto operandsIt = tensorFromElements.dynamicExtents().begin();
for (int64_t dim : resultType.getShape()) {
- if (dim != RankedTensorType::kDynamicSize) {
+ if (!ShapedType::isDynamic(dim)) {
newShape.push_back(dim);
continue;
}
APInt index;
if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
- newShape.push_back(RankedTensorType::kDynamicSize);
+ newShape.push_back(ShapedType::kDynamicSize);
newOperands.push_back(*operandsIt++);
continue;
}
@@ -661,7 +660,7 @@ static LogicalResult verify(ReshapeOp op) {
return op.emitOpError("source and destination tensor should have the "
"same number of elements");
}
- if (shapeSize == TensorType::kDynamicSize)
+ if (ShapedType::isDynamic(shapeSize))
return op.emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked tensor type");
if (shapeSize != resultRankedType.getRank())
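The StaticTensorGenerate hunk walks the result shape, consuming one dynamic-extent operand per dynamic dimension; if that operand is a constant, the dimension can become static. The excerpt cuts off before the constant branch, so this sketch assumes it records the constant's value. Operands are modelled as std::optional<int64_t>, where an empty value means "not a constant"; `refineShape` is hypothetical. In the real pattern the non-constant operands are kept on the new op.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamicSize = -1;  // same sentinel value as in the diff

// Replaces each dynamic dimension whose extent operand is a known constant
// with that constant; other dimensions are left unchanged.
std::vector<int64_t>
refineShape(const std::vector<int64_t> &shape,
            const std::vector<std::optional<int64_t>> &extents) {
  std::vector<int64_t> newShape;
  auto it = extents.begin();
  for (int64_t dim : shape) {
    if (dim != kDynamicSize) {
      newShape.push_back(dim);
      continue;
    }
    assert(it != extents.end() && "one extent per dynamic dimension");
    newShape.push_back(*it ? **it : kDynamicSize);
    ++it;
  }
  return newShape;
}
```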
diff --git a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
index a6fbc303d64e3..cda4472695b2d 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
@@ -172,13 +172,13 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
resStrides(bT.getRank(), 0);
for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
resShape[idx] =
- (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
+ (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
resStrides[idx] = (aStrides[idx] == bStrides[idx])
? aStrides[idx]
- : MemRefType::kDynamicStrideOrOffset;
+ : ShapedType::kDynamicStrideOrOffset;
}
resOffset =
- (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
+ (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
return MemRefType::get(
resShape, aT.getElementType(),
makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
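getCastCompatibleMemRefType computes an element-wise "join" of two layouts. Where both types agree, the value is kept; otherwise the appropriate dynamic sentinel is used (kDynamicSize for sizes, kDynamicStrideOrOffset for strides and the offset). A standalone sketch, assuming equal ranks as the excerpt implies; `mergeOrDynamic` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

constexpr int64_t kDynamicSize = -1;
constexpr int64_t kDynamicStrideOrOffset = std::numeric_limits<int64_t>::min();

// Keep a value where both inputs agree; otherwise use the dynamic sentinel.
std::vector<int64_t> mergeOrDynamic(const std::vector<int64_t> &a,
                                    const std::vector<int64_t> &b,
                                    int64_t dynamic) {
  assert(a.size() == b.size());
  std::vector<int64_t> res(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    res[i] = (a[i] == b[i]) ? a[i] : dynamic;
  return res;
}
```

The offset is merged the same way: (aOffset == bOffset) ? aOffset : kDynamicStrideOrOffset.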
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
new file mode 100644
index 0000000000000..aaa2233c24394
--- /dev/null
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -0,0 +1,51 @@
+//===- BuiltinTypeInterfaces.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ShapedType
+//===----------------------------------------------------------------------===//
+
+constexpr int64_t ShapedType::kDynamicSize;
+constexpr int64_t ShapedType::kDynamicStrideOrOffset;
+
+int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+ int64_t num = 1;
+ for (int64_t dim : shape) {
+ num *= dim;
+ assert(num >= 0 && "integer overflow in element count computation");
+ }
+ return num;
+}
+
+int64_t ShapedType::getSizeInBits() const {
+ assert(hasStaticShape() &&
+ "cannot get the bit size of an aggregate with a dynamic shape");
+
+ auto elementType = getElementType();
+ if (elementType.isIntOrFloat())
+ return elementType.getIntOrFloatBitWidth() * getNumElements();
+
+ if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ elementType = complexType.getElementType();
+ return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
+ }
+ return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
+}
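ShapedType::getNumElements multiplies the dimensions together. getSizeInBits scales the element bit width by that count, doubles it for complex elements, and otherwise recurses into a shaped element type. The sketch below models that recursion with a hypothetical `ToyType` struct instead of real MLIR types. It differs in one deliberate way: it checks for overflow before each multiplication. Signed overflow is undefined behavior in C++, so testing the sign after the multiply is not guaranteed to catch it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

// Hypothetical stand-in for an element type: a scalar of some bit width,
// a complex number over a scalar, or a nested shaped type.
struct ToyType {
  enum Kind { Scalar, Complex, Shaped } kind;
  int64_t bitWidth = 0;             // Scalar width, or Complex component width
  std::vector<int64_t> shape;       // Shaped only (static dims)
  std::shared_ptr<ToyType> element; // Shaped only
};

// Product of all dimensions, checking for overflow before each multiply.
int64_t getNumElements(const std::vector<int64_t> &shape) {
  int64_t num = 1;
  for (int64_t dim : shape) {
    assert(dim >= 0 && "expected a static, non-negative dimension");
    assert((dim == 0 || num <= std::numeric_limits<int64_t>::max() / dim) &&
           "integer overflow in element count computation");
    num *= dim;
  }
  return num;
}

// Bits needed to store a value of `type`. Scalars contribute their width and
// complex numbers twice their component width. Nested shaped types are sized
// recursively.
int64_t getSizeInBits(const ToyType &type) {
  switch (type.kind) {
  case ToyType::Scalar:
    return type.bitWidth;
  case ToyType::Complex:
    return 2 * type.bitWidth;
  case ToyType::Shaped:
    return getNumElements(type.shape) * getSizeInBits(*type.element);
  }
  return 0;
}
```

For example, a 2x3 shaped type of 32-bit scalars needs 6 * 32 = 192 bits. A 4-element shaped type of complex<f32> needs 4 * 64 = 256 bits.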
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6efd384ad3cce..4d3e44ff761d8 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,12 +32,6 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
-//===----------------------------------------------------------------------===//
-/// Tablegen Interface Definitions
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
-
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
@@ -271,171 +265,6 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-//===----------------------------------------------------------------------===//
-// ShapedType
-//===----------------------------------------------------------------------===//
-constexpr int64_t ShapedType::kDynamicSize;
-constexpr int64_t ShapedType::kDynamicStrideOrOffset;
-
-ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
- if (auto other = dyn_cast<MemRefType>()) {
- MemRefType::Builder b(other);
- b.setShape(shape);
- b.setElementType(elementType);
- return b;
- }
-
- if (auto other = dyn_cast<UnrankedMemRefType>()) {
- MemRefType::Builder b(shape, elementType);
- b.setMemorySpace(other.getMemorySpace());
- return b;
- }
-
- if (isa<TensorType>())
- return RankedTensorType::get(shape, elementType);
-
- if (auto vecTy = dyn_cast<VectorType>())
- return VectorType::get(shape, elementType, vecTy.getNumScalableDims());
-
- llvm_unreachable("Unhandled ShapedType clone case");
-}
-
-ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
- if (auto other = dyn_cast<MemRefType>()) {
- MemRefType::Builder b(other);
- b.setShape(shape);
- return b;
- }
-
- if (auto other = dyn_cast<UnrankedMemRefType>()) {
- MemRefType::Builder b(shape, other.getElementType());
- b.setShape(shape);
- b.setMemorySpace(other.getMemorySpace());
- return b;
- }
-
- if (isa<TensorType>())
- return RankedTensorType::get(shape, getElementType());
-
- if (auto vecTy = dyn_cast<VectorType>())
- return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());
-
- llvm_unreachable("Unhandled ShapedType clone case");
-}
-
-ShapedType ShapedType::clone(Type elementType) {
- if (auto other = dyn_cast<MemRefType>()) {
- MemRefType::Builder b(other);
- b.setElementType(elementType);
- return b;
- }
-
- if (auto other = dyn_cast<UnrankedMemRefType>()) {
- return UnrankedMemRefType::get(elementType, other.getMemorySpace());
- }
-
- if (isa<TensorType>()) {
- if (hasRank())
- return RankedTensorType::get(getShape(), elementType);
- return UnrankedTensorType::get(elementType);
- }
-
- if (auto vecTy = dyn_cast<VectorType>())
- return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());
-
- llvm_unreachable("Unhandled ShapedType clone hit");
-}
-
-Type ShapedType::getElementType() const {
- return TypeSwitch<Type, Type>(*this)
- .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
- UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
-}
-
-unsigned ShapedType::getElementTypeBitWidth() const {
- return getElementType().getIntOrFloatBitWidth();
-}
-
-int64_t ShapedType::getNumElements() const {
- assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
- auto shape = getShape();
- int64_t num = 1;
- for (auto dim : shape) {
- num *= dim;
- assert(num >= 0 && "integer overflow in element count computation");
- }
- return num;
-}
-
-int64_t ShapedType::getRank() const {
- assert(hasRank() && "cannot query rank of unranked shaped type");
- return getShape().size();
-}
-
-bool ShapedType::hasRank() const {
- return !isa<UnrankedMemRefType, UnrankedTensorType>();
-}
-
-int64_t ShapedType::getDimSize(unsigned idx) const {
- assert(idx < getRank() && "invalid index for shaped type");
- return getShape()[idx];
-}
-
-bool ShapedType::isDynamicDim(unsigned idx) const {
- assert(idx < getRank() && "invalid index for shaped type");
- return isDynamic(getShape()[idx]);
-}
-
-unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
- assert(index < getRank() && "invalid index");
- assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
- return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
-}
-
-/// Get the number of bits require to store a value of the given shaped type.
-/// Compute the value recursively since tensors are allowed to have vectors as
-/// elements.
-int64_t ShapedType::getSizeInBits() const {
- assert(hasStaticShape() &&
- "cannot get the bit size of an aggregate with a dynamic shape");
-
- auto elementType = getElementType();
- if (elementType.isIntOrFloat())
- return elementType.getIntOrFloatBitWidth() * getNumElements();
-
- if (auto complexType = elementType.dyn_cast<ComplexType>()) {
- elementType = complexType.getElementType();
- return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
- }
-
- // Tensors can have vectors and other tensors as elements, other shaped types
- // cannot.
- assert(isa<TensorType>() && "unsupported element type");
- assert((elementType.isa<VectorType, TensorType>()) &&
- "unsupported tensor element type");
- return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}
-
-ArrayRef<int64_t> ShapedType::getShape() const {
- if (auto vectorType = dyn_cast<VectorType>())
- return vectorType.getShape();
- if (auto tensorType = dyn_cast<RankedTensorType>())
- return tensorType.getShape();
- return cast<MemRefType>().getShape();
-}
-
-int64_t ShapedType::getNumDynamicDims() const {
- return llvm::count_if(getShape(), isDynamic);
-}
-
-bool ShapedType::hasStaticShape() const {
- return hasRank() && llvm::none_of(getShape(), isDynamic);
-}
-
-bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
- return hasStaticShape() && getShape() == shape;
-}
-
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
@@ -474,10 +303,44 @@ void VectorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
+VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType) {
+ return VectorType::get(shape.getValueOr(getShape()), elementType,
+ getNumScalableDims());
+}
+
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
+Type TensorType::getElementType() const {
+ return llvm::TypeSwitch<TensorType, Type>(*this)
+ .Case<RankedTensorType, UnrankedTensorType>(
+ [](auto type) { return type.getElementType(); });
+}
+
+bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
+
+ArrayRef<int64_t> TensorType::getShape() const {
+ return cast<RankedTensorType>().getShape();
+}
+
+TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
+ if (shape)
+ return RankedTensorType::get(*shape, elementType);
+ return UnrankedTensorType::get(elementType);
+ }
+
+ auto rankedTy = cast<RankedTensorType>();
+ if (!shape)
+ return RankedTensorType::get(rankedTy.getShape(), elementType,
+ rankedTy.getEncoding());
+ return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()),
+ elementType, rankedTy.getEncoding());
+}
+
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -542,6 +405,35 @@ void UnrankedTensorType::walkImmediateSubElements(
// BaseMemRefType
//===----------------------------------------------------------------------===//
+Type BaseMemRefType::getElementType() const {
+ return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
+ .Case<MemRefType, UnrankedMemRefType>(
+ [](auto type) { return type.getElementType(); });
+}
+
+bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
+
+ArrayRef<int64_t> BaseMemRefType::getShape() const {
+ return cast<MemRefType>().getShape();
+}
+
+BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
+ if (!shape)
+ return UnrankedMemRefType::get(elementType, getMemorySpace());
+ MemRefType::Builder builder(*shape, elementType);
+ builder.setMemorySpace(getMemorySpace());
+ return builder;
+ }
+
+ MemRefType::Builder builder(cast<MemRefType>());
+ if (shape)
+ builder.setShape(*shape);
+ builder.setElementType(elementType);
+ return builder;
+}
+
Attribute BaseMemRefType::getMemorySpace() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
return rankedMemRefTy.getMemorySpace();
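The new cloneWith hooks take an optional shape. TensorType::cloneWith returns a ranked tensor whenever a shape is supplied. When no shape is supplied, it preserves the current rank-ness and, for ranked tensors, the current shape. Below is a toy model of that case analysis. It assumes C++17 `std::optional` and uses a hypothetical `ToyTensor` struct that ignores the ranked-tensor encoding attribute.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy tensor type: an empty `shape` means unranked; the element
// type is just a name for illustration.
struct ToyTensor {
  std::optional<std::vector<int64_t>> shape; // nullopt => unranked
  std::string elementType;
  bool operator==(const ToyTensor &o) const {
    return shape == o.shape && elementType == o.elementType;
  }
};

// Mirrors the branches of TensorType::cloneWith: a provided shape always
// yields a ranked result, and no shape keeps the input's shape (or its
// unranked-ness) while swapping the element type.
ToyTensor cloneWith(const ToyTensor &t,
                    const std::optional<std::vector<int64_t>> &shape,
                    const std::string &elementType) {
  if (shape)
    return ToyTensor{*shape, elementType};
  return ToyTensor{t.shape, elementType};
}
```

Collapsing the four branches into two works here only because the toy type carries no encoding. The real implementation also has to forward `getEncoding()` for ranked inputs.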
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 53326ad2ec648..3ad65c4ffe103 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRIR
BuiltinAttributes.cpp
BuiltinDialect.cpp
BuiltinTypes.cpp
+ BuiltinTypeInterfaces.cpp
Diagnostics.cpp
Dialect.cpp
Dominance.cpp
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index b46029f1ba06c..82674fd3768b6 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -30,14 +30,14 @@ TEST(ShapedTypeTest, CloneMemref) {
AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
ShapedType memrefType =
- MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
+ (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setLayout(AffineMapAttr::get(map));
// Update shape.
llvm::SmallVector<int64_t> memrefNewShape({30, 40});
ASSERT_NE(memrefOriginalShape, memrefNewShape);
ASSERT_EQ(memrefType.clone(memrefNewShape),
- (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+ (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setLayout(AffineMapAttr::get(map)));
// Update type.
@@ -81,25 +81,29 @@ TEST(ShapedTypeTest, CloneTensor) {
// Update shape.
llvm::SmallVector<int64_t> tensorNewShape({30, 40});
ASSERT_NE(tensorOriginalShape, tensorNewShape);
- ASSERT_EQ(tensorType.clone(tensorNewShape),
- RankedTensorType::get(tensorNewShape, tensorOriginalType));
+ ASSERT_EQ(
+ tensorType.clone(tensorNewShape),
+ (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
// Update type.
Type tensorNewType = f32;
ASSERT_NE(tensorOriginalType, tensorNewType);
- ASSERT_EQ(tensorType.clone(tensorNewType),
- RankedTensorType::get(tensorOriginalShape, tensorNewType));
+ ASSERT_EQ(
+ tensorType.clone(tensorNewType),
+ (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
// Update both.
ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
- RankedTensorType::get(tensorNewShape, tensorNewType));
+ (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
// Test unranked tensor cloning.
ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
- ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
- RankedTensorType::get(tensorNewShape, tensorOriginalType));
+ ASSERT_EQ(
+ unrankedTensorType.clone(tensorNewShape),
+ (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
- UnrankedTensorType::get(tensorNewType));
- ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
- RankedTensorType::get(tensorNewShape, tensorOriginalType));
+ (ShapedType)UnrankedTensorType::get(tensorNewType));
+ ASSERT_EQ(
+ unrankedTensorType.clone(tensorNewShape),
+ (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
}
TEST(ShapedTypeTest, CloneVector) {