[Mlir-commits] [mlir] [mlir] Convert TensorType and BaseMemRefType to interfaces (PR #133053)
Andrei Golubev
llvmlistbot at llvm.org
Wed Mar 26 02:00:22 PDT 2025
https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/133053
The existing design assumes that "TensorType" is only the built-in (un)ranked tensor and "BaseMemRefType" is only the built-in (un)ranked memref. This means that generic logic operating on "tensors" and "memrefs" is limited to the built-ins, with no compatible user types allowed. This matters, for instance, in one-shot bufferization when converting a "user tensor" to a "user memref" via the common infrastructure.
Remove this behaviour, which seems accidental, by following in the footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10). As with ShapedType, "tensor" and "memref" seem to have always aspired to be interfaces.
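To illustrate, with this change a downstream dialect can declare its own tensor-like type by attaching TensorTypeInterface directly. The following is a minimal sketch modelled on the TestTensorType added to TestTypeDefs.td in this patch; the MyDialect_Type/MyTensor names are hypothetical placeholders for a downstream dialect's own type base class:

  def MyTensorType : MyDialect_Type<"MyTensor", [TensorTypeInterface]> {
    let mnemonic = "my_tensor";
    let parameters = (ins ArrayRefParameter<"int64_t">:$shape,
                          "mlir::Type":$elementType);
    let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
    let extraClassDeclaration = [{
      // ShapedTypeInterface hooks: this type is always ranked.
      bool hasRank() const { return true; }
      MyTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
                             mlir::Type elementType) const {
        return MyTensorType::get(getContext(), shape.value_or(getShape()),
                                 elementType);
      }
    }];
  }

Values of such a type then satisfy mlir::isa<mlir::TensorType>(...), so generic code such as the bufferization infrastructure can treat it like the built-in tensor types (see the unit tests added in InterfaceTest.cpp below).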
From 18d0dbac63e4a45f493ad52fe3d5f54d6ae24c91 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Tue, 25 Mar 2025 16:10:06 +0000
Subject: [PATCH] [mlir] Convert TensorType and BaseMemRefType to interfaces
The existing design assumes that "TensorType" is only the built-in
(un)ranked tensor and "BaseMemRefType" is only the built-in (un)ranked
memref. This means that generic logic operating on "tensors" and
"memrefs" is limited to the built-ins, with no compatible user types
allowed. This matters, for instance, in one-shot bufferization when
converting a "user tensor" to a "user memref" via the common
infrastructure.
Remove this behaviour, which seems accidental, by following in the
footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10).
As with ShapedType, "tensor" and "memref" seem to have always aspired
to be interfaces.
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 12 +-
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 68 +++++++++-
mlir/include/mlir/IR/BuiltinTypes.h | 110 ----------------
mlir/include/mlir/IR/BuiltinTypes.td | 58 ++++++---
.../Conversion/TosaToTensor/TosaToTensor.cpp | 12 +-
.../IR/BufferizableOpInterface.cpp | 3 +-
.../BufferizableOpInterfaceImpl.cpp | 27 ++--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 7 +
mlir/lib/IR/BuiltinTypes.cpp | 121 +++++-------------
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 46 +++++++
mlir/unittests/IR/InterfaceTest.cpp | 34 +++++
11 files changed, 250 insertions(+), 248 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
}
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
- [ShapedTypeInterface], "::mlir::TensorType"> {
+ [TensorTypeInterface]> {
let summary = "TensorDesc describing regions of interested data.";
let description = [{
TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
+ TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+ bool hasRank() const { return true; }
+
TensorDescType clone(::mlir::Type elementType) {
- return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}
- int getArrayLength() {
+ int getArrayLength() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return 1;
}
- bool getBoundaryCheck() {
+ bool getBoundaryCheck() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
+ }];
+ let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return cloneWith(shape, elementType);
+ return $_type.cloneWith(shape, elementType);
}
/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return cloneWith(shape, getElementType());
+ return $_type.cloneWith(shape, $_type.getElementType());
}
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new element type. The
/// returned type is ranked if and only if this type is ranked. In that
/// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified tensor types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a tensor.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::TensorType::Trait>();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified memref types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space in which data referred to by this memref resides.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ [deprecated] Returns the memory space in old raw integer representation.
+ New `Attribute getMemorySpace()` method should be used instead.
+ }],
+ "unsigned", "getMemorySpaceAsInt">,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a memref.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+ }];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this tensor type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this tensor type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<TensorType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this memref type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this memref type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<BaseMemRefType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a memref.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Returns the memory space in which data referred to by this memref resides.
- Attribute getMemorySpace() const;
-
- /// [deprecated] Returns the memory space in old raw integer representation.
- /// New `Attribute getMemorySpace()` method should be used instead.
- unsigned getMemorySpaceAsInt() const;
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
-inline bool BaseMemRefType::classof(Type type) {
- return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool TensorType::classof(Type type) {
- return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference to a region of memory";
let description = [{
Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<MemRefType>::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Return "true" if the last dimension has a static unit stride. Also
/// return "true" for types with no strides.
bool isLastDimUnitStride();
+
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<RankedTensorType>::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
/// Return a clone of this type with the given new element type and the same
/// shape as this type.
RankedTensorType clone(::mlir::Type elementType) {
- return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
/// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
return RankedTensorType::get(getShape(), getElementType());
}
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
/// Return a clone of this type with the given new encoding and the same
/// shape and element type as this type.
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}]>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<UnrankedMemRefType>::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
- /// Return a clone of this type with the given new element type and the same
- /// shape as this type.
- MemRefType clone(::mlir::Type elementType) {
- return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
- }
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<UnrankedTensorType>::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
- return input.getType().clone(shape);
+ return mlir::cast<TensorType>(input.getType().clone(shape));
}
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
// with just '{}', as it will invoke the incorrect overload.
if (newShape.empty())
- return inputType.clone(ArrayRef<int64_t>{});
+ return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
// Check if the input is static, and if so, get its total size
bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
assert(!inputIsStatic || resultIsStatic);
// Create result type
- return inputType.clone(resultShape);
+ return mlir::cast<TensorType>(inputType.clone(resultShape));
}
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
auto rhsShape = rhsType.getShape();
if (lhsShape.empty() || rhsShape.empty())
- return lhsType.clone(ArrayRef<int64_t>{});
+ return mlir::cast<TensorType>(lhsType.clone(ArrayRef<int64_t>{}));
if (ShapedType::isDynamicShape(lhsShape) ||
ShapedType::isDynamicShape(rhsShape))
- return lhsType.clone({ShapedType::kDynamic});
+ return mlir::cast<TensorType>(lhsType.clone({ShapedType::kDynamic}));
SmallVector<int64_t> intermediateShape;
unsigned currLhsDim = 0, currRhsDim = 0;
@@ -149,7 +149,7 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
assert(rhsShape[currRhsDim] == 1);
}
- return lhsType.clone(intermediateShape);
+ return mlir::cast<TensorType>(lhsType.clone(intermediateShape));
}
SmallVector<ReassociationExprs>
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..e15f81a1ef433 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -735,8 +735,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
if (llvm::isa<TensorType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
+ assert(llvm::isa<BaseMemRefType>(replacement.getType()) &&
"tensor op result should be replaced with a memref value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 4ac6eca586961..8f9438babc070 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -77,9 +77,9 @@ struct CastOpInterface
// Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
// change.
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
- return MemRefType::get(
+ return llvm::cast<BaseMemRefType>(MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
- llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+ llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -157,8 +157,9 @@ struct CollapseShapeOpInterface
tensorResultType, srcBufferType.getMemorySpace());
}
- return memref::CollapseShapeOp::computeCollapsedType(
- srcBufferType, collapseShapeOp.getReassociationIndices());
+ return llvm::cast<BaseMemRefType>(
+ memref::CollapseShapeOp::computeCollapsedType(
+ srcBufferType, collapseShapeOp.getReassociationIndices()));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -325,7 +326,7 @@ struct ExpandShapeOpInterface
expandShapeOp.getReassociationIndices());
if (failed(maybeResultType))
return failure();
- return *maybeResultType;
+ return llvm::cast<BaseMemRefType>(*maybeResultType);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -405,10 +406,11 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(),
- llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
- mixedStrides);
+ return mlir::cast<BaseMemRefType>(
+ memref::SubViewOp::inferRankReducedResultType(
+ extractSliceOp.getType().getShape(),
+ llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+ mixedStrides));
}
};
@@ -746,9 +748,10 @@ struct PadOpInterface
if (failed(maybeSrcBufferType))
return failure();
MemRefLayoutAttrInterface layout;
- return MemRefType::get(padOp.getResultType().getShape(),
- padOp.getResultType().getElementType(), layout,
- maybeSrcBufferType->getMemorySpace());
+ return llvm::cast<BaseMemRefType>(
+ MemRefType::get(padOp.getResultType().getShape(),
+ padOp.getResultType().getElementType(), layout,
+ maybeSrcBufferType->getMemorySpace()));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78c242571935c..1fcf9df052ae3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -396,6 +396,13 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
getElementType());
}
+TensorDescType TensorDescType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return TensorDescType::get(shape.value_or(this->getShape()), elementType,
+ this->getArrayLength(), this->getBoundaryCheck(),
+ this->getMemorySpace(), this->getSgMap());
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..02e7038a75fff 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
@@ -256,45 +257,6 @@ VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
// TensorType
//===----------------------------------------------------------------------===//
-Type TensorType::getElementType() const {
- return llvm::TypeSwitch<TensorType, Type>(*this)
- .Case<RankedTensorType, UnrankedTensorType>(
- [](auto type) { return type.getElementType(); });
-}
-
-bool TensorType::hasRank() const {
- return !llvm::isa<UnrankedTensorType>(*this);
-}
-
-ArrayRef<int64_t> TensorType::getShape() const {
- return llvm::cast<RankedTensorType>(*this).getShape();
-}
-
-TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const {
- if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
- if (shape)
- return RankedTensorType::get(*shape, elementType);
- return UnrankedTensorType::get(elementType);
- }
-
- auto rankedTy = llvm::cast<RankedTensorType>(*this);
- if (!shape)
- return RankedTensorType::get(rankedTy.getShape(), elementType,
- rankedTy.getEncoding());
- return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
- rankedTy.getEncoding());
-}
-
-RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
- Type elementType) const {
- return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
-}
-
-RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
- return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
-}
-
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -317,6 +279,12 @@ bool TensorType::isValidElementType(Type type) {
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
+RankedTensorType
+RankedTensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return RankedTensorType::get(shape.value_or(this->getShape()), elementType,
+ this->getEncoding());
+}
LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -335,6 +303,13 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
// UnrankedTensorType
//===----------------------------------------------------------------------===//
+TensorType UnrankedTensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (shape)
+ return RankedTensorType::get(*shape, elementType);
+ return UnrankedTensorType::get(elementType);
+}
+
LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
@@ -342,65 +317,18 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
}
//===----------------------------------------------------------------------===//
-// BaseMemRefType
+// MemRefType
//===----------------------------------------------------------------------===//
-Type BaseMemRefType::getElementType() const {
- return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
- .Case<MemRefType, UnrankedMemRefType>(
- [](auto type) { return type.getElementType(); });
-}
-
-bool BaseMemRefType::hasRank() const {
- return !llvm::isa<UnrankedMemRefType>(*this);
-}
-
-ArrayRef<int64_t> BaseMemRefType::getShape() const {
- return llvm::cast<MemRefType>(*this).getShape();
-}
-
-BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const {
- if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
- if (!shape)
- return UnrankedMemRefType::get(elementType, getMemorySpace());
- MemRefType::Builder builder(*shape, elementType);
- builder.setMemorySpace(getMemorySpace());
- return builder;
- }
-
- MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+MemRefType MemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ MemRefType::Builder builder(*this);
if (shape)
builder.setShape(*shape);
builder.setElementType(elementType);
- return builder;
-}
-
-MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
- Type elementType) const {
- return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
-}
-
-MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
- return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+ return MemRefType(builder);
}
-Attribute BaseMemRefType::getMemorySpace() const {
- if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
- return rankedMemRefTy.getMemorySpace();
- return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
-}
-
-unsigned BaseMemRefType::getMemorySpaceAsInt() const {
- if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
- return rankedMemRefTy.getMemorySpaceAsInt();
- return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
-}
-
-//===----------------------------------------------------------------------===//
-// MemRefType
-//===----------------------------------------------------------------------===//
-
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape,
@@ -888,6 +816,17 @@ bool MemRefType::isLastDimUnitStride() {
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
+BaseMemRefType
+UnrankedMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (!shape)
+ return UnrankedMemRefType::get(elementType, getMemorySpace());
+
+ MemRefType::Builder builder(*shape, elementType);
+ builder.setMemorySpace(getMemorySpace());
+ return MemRefType(builder);
+}
+
unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
return detail::getMemorySpaceAsInt(getMemorySpace());
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..61fab9c889be7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -403,4 +403,50 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
let mnemonic = "op_asm_type_interface";
}
+def TestTensorType : Test_Type<"TestTensor", [TensorTypeInterface]> {
+ let mnemonic = "test_tensor";
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "mlir::Type":$elementType
+ );
+ let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
+
+ let extraClassDeclaration = [{
+ // ShapedTypeInterface:
+ bool hasRank() const {
+ return true;
+ }
+ test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+ return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType);
+ }
+ }];
+}
+
+def TestMemrefType : Test_Type<"TestMemref", [BaseMemRefTypeInterface]> {
+ let mnemonic = "test_memref";
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "mlir::Type":$elementType,
+ DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
+ );
+ let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
+
+ let extraClassDeclaration = [{
+ // ShapedTypeInterface:
+ bool hasRank() const {
+ return true;
+ }
+ test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+ return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace());
+ }
+
+ // BaseMemRefTypeInterface:
+ mlir::Attribute getMemorySpace() const {
+ return getMemSpace();
+ }
+ // [deprecated]
+ unsigned getMemorySpaceAsInt() const { return 0; }
+ }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7da..004367f7759f2 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -9,6 +9,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OwningOpRef.h"
#include "gtest/gtest.h"
@@ -84,3 +85,36 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}
+
+TEST(InterfaceTest, TestCustomTensorIsTensorType) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ auto customTensorType = test::TestTensorType::get(
+ &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+ EXPECT_TRUE(mlir::isa<mlir::TensorType>(customTensorType));
+
+ auto customCloneType = customTensorType.clone({3, 4, 5});
+ EXPECT_EQ(customTensorType.getElementType(),
+ customCloneType.getElementType());
+ EXPECT_TRUE(mlir::isa<mlir::TensorType>(customCloneType));
+ EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));
+}
+
+TEST(InterfaceTest, TestCustomMemrefIsBaseMemref) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ auto customMemrefType = test::TestMemrefType::get(
+ &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
+ mlir::StringAttr::get(&context, "some_memspace"));
+ EXPECT_TRUE(mlir::isa<mlir::BaseMemRefType>(customMemrefType));
+
+ auto customCloneType = customMemrefType.clone({3, 4, 5});
+ EXPECT_EQ(customMemrefType.getElementType(),
+ customCloneType.getElementType());
+ EXPECT_TRUE(mlir::isa<mlir::BaseMemRefType>(customCloneType));
+ EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
+ EXPECT_EQ(customMemrefType.getMemorySpace(),
+ mlir::cast<test::TestMemrefType>(customCloneType).getMemorySpace());
+}