[Mlir-commits] [mlir] [mlir] Convert TensorType and BaseMemRefType to interfaces (PR #133053)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 26 02:00:59 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Andrei Golubev (andrey-golubev)

<details>
<summary>Changes</summary>

The existing design assumes that "TensorType" is only a built-in (un)ranked tensor and that "BaseMemRefType" is only a built-in (un)ranked memref. As a result, generic logic operating on "tensors" and "memrefs" is limited to built-in types; compatible user-defined types are not supported. This matters, for instance, in one-shot bufferization when converting a "user tensor" to a "user memref" via the common infrastructure.

Remove this seemingly accidental restriction by following in the footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10). As with ShapedType, "tensor" and "memref" appear to have always been meant to be interfaces.

---

Patch is 36.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133053.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+7-5) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+64-4) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (-110) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+40-18) 
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+6-6) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+15-12) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+7) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+30-91) 
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+46) 
- (modified) mlir/unittests/IR/InterfaceTest.cpp (+34) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
 }
 
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
-        [ShapedTypeInterface], "::mlir::TensorType"> {
+        [TensorTypeInterface]> {
   let summary = "TensorDesc describing regions of interested data.";
   let description = [{
     TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
     using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
 
+    TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+    bool hasRank() const { return true; }
+
     TensorDescType clone(::mlir::Type elementType) {
-      return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
     }
 
-    int getArrayLength() {
+    int getArrayLength() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return 1;
     }
 
-    bool getBoundaryCheck() {
+    bool getBoundaryCheck() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
+  }];
 
+  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return cloneWith(shape, elementType);
+      return $_type.cloneWith(shape, elementType);
     }
 
     /// Return a clone of this type with the given new shape. The returned type
     /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return cloneWith(shape, getElementType());
+      return $_type.cloneWith(shape, $_type.getElementType());
     }
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new element type. The
     /// returned type is ranked if and only if this type is ranked. In that
     /// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified tensor types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a tensor.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::TensorType::Trait>();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified memref types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space in which data referred to by this memref resides.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      [deprecated] Returns the memory space in old raw integer representation.
+      New `Attribute getMemorySpace()` method should be used instead.
+    }],
+    "unsigned", "getMemorySpaceAsInt">,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a memref.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this tensor type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this tensor type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<TensorType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this memref type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this memref type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                           Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<BaseMemRefType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a memref.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  Attribute getMemorySpace() const;
-
-  /// [deprecated] Returns the memory space in old raw integer representation.
-  /// New `Attribute getMemorySpace()` method should be used instead.
-  unsigned getMemorySpaceAsInt() const;
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
-inline bool BaseMemRefType::classof(Type type) {
-  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
 inline bool BaseMemRefType::isValidElementType(Type type) {
   return type.isIntOrIndexOrFloat() ||
          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool TensorType::classof(Type type) {
-  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<MemRefType>::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Return "true" if the last dimension has a static unit stride. Also
     /// return "true" for types with no strides.
     bool isLastDimUnitStride();
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<RankedTensorType>::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     /// Return a clone of this type with the given new element type and the same
     /// shape as this type.
     RankedTensorType clone(::mlir::Type elementType) {
-      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     /// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
       return RankedTensorType::get(getShape(), getElementType());
     }
 
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
     /// Return a clone of this type with the given new encoding and the same
     /// shape and element type as this type.
     RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
-    /// Return a clone of this type with the given new element type and the same
-    /// shape as this type.
-    MemRefType clone(::mlir::Type elementType) {
-      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
-    }
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<UnrankedTensorType>::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
 
     ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
 }
 
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Special case for 0D output tensor. Note: Watch out when using Type::clone()
   // with just '{}', as it will invoke the incorrect overload.
   if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
 
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return inputType.clone(resultShape);
+  return mlir::cast<TensorType>(inputType.clone(resultShape));
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, Ten...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/133053


More information about the Mlir-commits mailing list