[Mlir-commits] [mlir] [mlir] Convert TensorType and BaseMemRefType to interfaces (PR #133053)
llvmlistbot at llvm.org
Wed Mar 26 02:00:59 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-tensor
Author: Andrei Golubev (andrey-golubev)
Changes:
The existing design assumes that "TensorType" is only the built-in (un)ranked tensor and that "BaseMemRefType" is only the built-in (un)ranked memref. As a result, generic logic operating on "tensors" and "memrefs" is limited to the built-ins, with no room for compatible user-defined types. This matters, for instance, in one-shot bufferization when converting a "user tensor" to a "user memref" via the common infrastructure.
Remove this behaviour, which appears to be accidental, by following in the footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10) and turning both classes into type interfaces. As with ShapedType, "tensor" and "memref" have always aspired to be interfaces.
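For illustration, a downstream dialect type could now opt into the tensor interface roughly as follows (a sketch only: the type and base-class names are hypothetical, mirroring the pattern used by the updated XeGPU TensorDesc and the test types in TestTypeDefs.td):

```tablegen
// Hypothetical user-defined tensor-like type; MyDialect_Type is assumed to be
// the dialect's usual TypeDef wrapper base class.
def MyDialect_FancyTensor : MyDialect_Type<"FancyTensor", "fancy_tensor",
    [TensorTypeInterface]> {
  // getShape()/getElementType() accessors are generated from the parameters.
  let parameters = (ins ArrayRefParameter<"int64_t">:$shape,
                        "::mlir::Type":$elementType);
  let extraClassDeclaration = [{
    // Remaining ShapedTypeInterface hooks: this type is always ranked, and
    // cloneWith() is implemented in C++ alongside the dialect's other types.
    bool hasRank() const { return true; }
    FancyTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                              ::mlir::Type elementType) const;
  }];
  let hasCustomAssemblyFormat = 1;
}
```

Attaching `TensorTypeInterface` (which itself carries `ShapedTypeInterface`) is what lets such a type satisfy `isa<TensorType>` checks and flow through tensor-generic code paths.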
---
Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+7-5)
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+64-4)
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (-110)
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+40-18)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+6-6)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+15-12)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+7)
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+30-91)
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+46)
- (modified) mlir/unittests/IR/InterfaceTest.cpp (+34)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
}
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
- [ShapedTypeInterface], "::mlir::TensorType"> {
+ [TensorTypeInterface]> {
let summary = "TensorDesc describing regions of interested data.";
let description = [{
TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
+ TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+ bool hasRank() const { return true; }
+
TensorDescType clone(::mlir::Type elementType) {
- return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}
- int getArrayLength() {
+ int getArrayLength() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return 1;
}
- bool getBoundaryCheck() {
+ bool getBoundaryCheck() const {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
+ }];
+ let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
/// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return cloneWith(shape, elementType);
+ return $_type.cloneWith(shape, elementType);
}
/// Return a clone of this type with the given new shape. The returned type
/// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return cloneWith(shape, getElementType());
+ return $_type.cloneWith(shape, $_type.getElementType());
}
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new element type. The
/// returned type is ranked if and only if this type is ranked. In that
/// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified tensor types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a tensor.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::TensorType::Trait>();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface provides a shared interface type for ranked, unranked and any
+ user-specified memref types.
+
+ This interface attaches the ShapedTypeInterface to act as a mixin to
+ provide many useful utility functions.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space in which data referred to by this memref resides.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ [deprecated] Returns the memory space in old raw integer representation.
+ New `Attribute getMemorySpace()` method should be used instead.
+ }],
+ "unsigned", "getMemorySpaceAsInt">,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return true if the specified element type is ok in a memref.
+ static bool isValidElementType(::mlir::Type type);
+ }];
+
+ let extraClassOf = [{
+ return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+ }];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this tensor type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this tensor type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<TensorType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a tensor.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-/// provide many useful utility functions. This inheritance has no effect
-/// on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
- using Type::Type;
-
- /// Returns the element type of this memref type.
- Type getElementType() const;
-
- /// Returns if this type is ranked, i.e. it has a known number of dimensions.
- bool hasRank() const;
-
- /// Returns the shape of this memref type.
- ArrayRef<int64_t> getShape() const;
-
- /// Clone this type with the given shape and element type. If the
- /// provided shape is `std::nullopt`, the current shape of the type is used.
- BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
- Type elementType) const;
-
- // Make sure that base class overloads are visible.
- using ShapedType::Trait<BaseMemRefType>::clone;
-
- /// Return a clone of this type with the given new shape and element type.
- /// The returned type is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
- /// Return a clone of this type with the given new shape. The returned type
- /// is ranked, even if this type is unranked.
- MemRefType clone(ArrayRef<int64_t> shape) const;
-
- /// Return true if the specified element type is ok in a memref.
- static bool isValidElementType(Type type);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Returns the memory space in which data referred to by this memref resides.
- Attribute getMemorySpace() const;
-
- /// [deprecated] Returns the memory space in old raw integer representation.
- /// New `Attribute getMemorySpace()` method should be used instead.
- unsigned getMemorySpaceAsInt() const;
-
- /// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
-inline bool BaseMemRefType::classof(Type type) {
- return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool TensorType::classof(Type type) {
- return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference to a region of memory";
let description = [{
Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<MemRefType>::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Return "true" if the last dimension has a static unit stride. Also
/// return "true" for types with no strides.
bool isLastDimUnitStride();
+
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<RankedTensorType>::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
/// Return a clone of this type with the given new element type and the same
/// shape as this type.
RankedTensorType clone(::mlir::Type elementType) {
- return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ return cloneWith(getShape(), elementType);
}
/// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
return RankedTensorType::get(getShape(), getElementType());
}
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
/// Return a clone of this type with the given new encoding and the same
/// shape and element type as this type.
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
- ShapedTypeInterface
- ], "BaseMemRefType"> {
+ BaseMemRefTypeInterface
+ ]> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}]>
];
let extraClassDeclaration = [{
- using BaseMemRefType::clone;
+ using ShapedType::Trait<UnrankedMemRefType>::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
- /// Return a clone of this type with the given new element type and the same
- /// shape as this type.
- MemRefType clone(::mlir::Type elementType) {
- return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
- }
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
- ShapedTypeInterface, ValueSemantics
- ], "TensorType"> {
+ TensorTypeInterface, ValueSemantics
+ ]> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}]>
];
let extraClassDeclaration = [{
- using TensorType::clone;
+ using ShapedType::Trait<UnrankedTensorType>::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+ /// Returns if this type is ranked (always false).
+ bool hasRank() const { return false; }
+
+ /// Returns a clone of this type with the given shape and element
+ /// type. If a shape is not provided, the current shape of the type is used.
+ TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
// 0D tensor. While such construct is not incorrect on its own, bufferization
// cannot properly handle it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
- return input.getType().clone(shape);
+ return mlir::cast<TensorType>(input.getType().clone(shape));
}
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
// with just '{}', as it will invoke the incorrect overload.
if (newShape.empty())
- return inputType.clone(ArrayRef<int64_t>{});
+ return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
// Check if the input is static, and if so, get its total size
bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
assert(!inputIsStatic || resultIsStatic);
// Create result type
- return inputType.clone(resultShape);
+ return mlir::cast<TensorType>(inputType.clone(resultShape));
}
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]
``````````
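With the classes turned into interfaces, generic code keeps working unchanged for the built-ins and additionally accepts conforming user types. A minimal sketch of the intended usage (assumed, not taken verbatim from the new InterfaceTest.cpp tests):

```cpp
#include "mlir/IR/BuiltinTypes.h"

// Succeeds for RankedTensorType, UnrankedTensorType, and any user-defined type
// implementing TensorTypeInterface (e.g. the test types added in this patch).
static bool isF32Tensor(mlir::Type type) {
  auto tensorType = mlir::dyn_cast<mlir::TensorType>(type);
  return tensorType && tensorType.getElementType().isF32();
}
```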
https://github.com/llvm/llvm-project/pull/133053