[Mlir-commits] [mlir] [mlir] Convert TensorType and BaseMemRefType to interfaces (PR #133053)

Andrei Golubev llvmlistbot at llvm.org
Wed Mar 26 02:00:22 PDT 2025


https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/133053

The existing design assumes that "TensorType" is only ever a built-in (un)ranked tensor and that "BaseMemRefType" is only ever a built-in (un)ranked memref. As a result, generic logic operating on "tensors" and "memrefs" is limited to the built-in types, and compatible user-defined types cannot take part. For instance, this matters in one-shot bufferization when converting a "user tensor" to a "user memref" via the common infrastructure.

Remove this restriction, which seems accidental, by following in the footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10). As with ShapedType, "tensor" and "memref" appear to have always been intended as interfaces.

>From 18d0dbac63e4a45f493ad52fe3d5f54d6ae24c91 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Tue, 25 Mar 2025 16:10:06 +0000
Subject: [PATCH] [mlir] Convert TensorType and BaseMemRefType to interfaces

The existing design assumes that "TensorType" is only ever a built-in
(un)ranked tensor and that "BaseMemRefType" is only ever a built-in
(un)ranked memref. As a result, generic logic operating on "tensors"
and "memrefs" is limited to the built-in types, and compatible
user-defined types cannot take part. For instance, this matters in
one-shot bufferization when converting a "user tensor" to a "user
memref" via the common infrastructure.

Remove this restriction, which seems accidental, by following in the
footsteps of ShapedType (see 676bfb2a226e705d801016bb433b68a1e09a1e10).
As with ShapedType, "tensor" and "memref" appear to have always been
intended as interfaces.
---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  12 +-
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td |  68 +++++++++-
 mlir/include/mlir/IR/BuiltinTypes.h           | 110 ----------------
 mlir/include/mlir/IR/BuiltinTypes.td          |  58 ++++++---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  |  12 +-
 .../IR/BufferizableOpInterface.cpp            |   3 +-
 .../BufferizableOpInterfaceImpl.cpp           |  27 ++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |   7 +
 mlir/lib/IR/BuiltinTypes.cpp                  | 121 +++++-------------
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  46 +++++++
 mlir/unittests/IR/InterfaceTest.cpp           |  34 +++++
 11 files changed, 250 insertions(+), 248 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccd91a928e1dd..248ef9f855b14 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -31,7 +31,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
 }
 
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
-        [ShapedTypeInterface], "::mlir::TensorType"> {
+        [TensorTypeInterface]> {
   let summary = "TensorDesc describing regions of interested data.";
   let description = [{
     TensorDesc is a type designed to describe regions of the interested data as well as some
@@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
     using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
 
+    TensorDescType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+    bool hasRank() const { return true; }
+
     TensorDescType clone(::mlir::Type elementType) {
-      return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
@@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
     }
 
-    int getArrayLength() {
+    int getArrayLength() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
@@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return 1;
     }
 
-    bool getBoundaryCheck() {
+    bool getBoundaryCheck() const {
       auto attr = getEncoding();
       auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
       assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..a26b7f25fcf10 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
+  }];
 
+  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
     /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return cloneWith(shape, elementType);
+      return $_type.cloneWith(shape, elementType);
     }
 
     /// Return a clone of this type with the given new shape. The returned type
     /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return cloneWith(shape, getElementType());
+      return $_type.cloneWith(shape, $_type.getElementType());
     }
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new element type. The
     /// returned type is ranked if and only if this type is ranked. In that
     /// case, the returned type has the same shape as this type.
@@ -227,4 +227,64 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TensorTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified tensor types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a tensor.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::TensorType::Trait>();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// BaseMemRefTypeInterface
+//===----------------------------------------------------------------------===//
+
+def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a shared interface type for ranked, unranked and any
+    user-specified memref types.
+
+    This interface attaches the ShapedTypeInterface to act as a mixin to
+    provide many useful utility functions.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space in which data referred to by this memref resides.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      [deprecated] Returns the memory space in old raw integer representation.
+      New `Attribute getMemorySpace()` method should be used instead.
+    }],
+    "unsigned", "getMemorySpaceAsInt">,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return true if the specified element type is ok in a memref.
+    static bool isValidElementType(::mlir::Type type);
+  }];
+
+  let extraClassOf = [{
+    return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..4f3365492f720 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,108 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// TensorType
-//===----------------------------------------------------------------------===//
-
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived tensor types.
-class TensorType : public Type, public ShapedType::Trait<TensorType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this tensor type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this tensor type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<TensorType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  RankedTensorType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
-
-/// This class provides a shared interface for ranked and unranked memref types.
-/// Note: This class attaches the ShapedType trait to act as a mixin to
-///       provide many useful utility functions. This inheritance has no effect
-///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
-public:
-  using Type::Type;
-
-  /// Returns the element type of this memref type.
-  Type getElementType() const;
-
-  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
-  bool hasRank() const;
-
-  /// Returns the shape of this memref type.
-  ArrayRef<int64_t> getShape() const;
-
-  /// Clone this type with the given shape and element type. If the
-  /// provided shape is `std::nullopt`, the current shape of the type is used.
-  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                           Type elementType) const;
-
-  // Make sure that base class overloads are visible.
-  using ShapedType::Trait<BaseMemRefType>::clone;
-
-  /// Return a clone of this type with the given new shape and element type.
-  /// The returned type is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
-
-  /// Return a clone of this type with the given new shape. The returned type
-  /// is ranked, even if this type is unranked.
-  MemRefType clone(ArrayRef<int64_t> shape) const;
-
-  /// Return true if the specified element type is ok in a memref.
-  static bool isValidElementType(Type type);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  Attribute getMemorySpace() const;
-
-  /// [deprecated] Returns the memory space in old raw integer representation.
-  /// New `Attribute getMemorySpace()` method should be used instead.
-  unsigned getMemorySpaceAsInt() const;
-
-  /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-};
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -390,10 +288,6 @@ class FixedVectorType : public VectorType {
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
 
-inline bool BaseMemRefType::classof(Type type) {
-  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
-}
-
 inline bool BaseMemRefType::isValidElementType(Type type) {
   return type.isIntOrIndexOrFloat() ||
          llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool TensorType::classof(Type type) {
-  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
-}
-
 //===----------------------------------------------------------------------===//
 // Type Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..575ae6a263b1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
     Syntax:
@@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<MemRefType>::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Return "true" if the last dimension has a static unit stride. Also
     /// return "true" for types with no strides.
     bool isLastDimUnitStride();
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
     Syntax:
@@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<RankedTensorType>::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     /// Return a clone of this type with the given new element type and the same
     /// shape as this type.
     RankedTensorType clone(::mlir::Type elementType) {
-      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+      return cloneWith(getShape(), elementType);
     }
 
     /// Return a clone of this type without the encoding.
@@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
       return RankedTensorType::get(getShape(), getElementType());
     }
 
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
+
     /// Return a clone of this type with the given new encoding and the same
     /// shape and element type as this type.
     RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
@@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
-    ShapedTypeInterface
-  ], "BaseMemRefType"> {
+    BaseMemRefTypeInterface
+  ]> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
     Syntax:
@@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using BaseMemRefType::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
-    /// Return a clone of this type with the given new element type and the same
-    /// shape as this type.
-    MemRefType clone(::mlir::Type elementType) {
-      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
-    }
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
-    ShapedTypeInterface, ValueSemantics
-  ], "TensorType"> {
+    TensorTypeInterface, ValueSemantics
+  ]> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
     Syntax:
@@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using TensorType::clone;
+    using ShapedType::Trait<UnrankedTensorType>::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
     using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
 
     ArrayRef<int64_t> getShape() const { return std::nullopt; }
+
+    /// Returns if this type is ranked (always false).
+    bool hasRank() const { return false; }
+
+    /// Returns a clone of this type with the given shape and element
+    /// type. If a shape is not provided, the current shape of the type is used.
+    TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5f23a33049f87..1d1bcee8600a8 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
   // 0D tensor. While such construct is not incorrect on its own, bufferization
   // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
-  return input.getType().clone(shape);
+  return mlir::cast<TensorType>(input.getType().clone(shape));
 }
 
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
@@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Special case for 0D output tensor. Note: Watch out when using Type::clone()
   // with just '{}', as it will invoke the incorrect overload.
   if (newShape.empty())
-    return inputType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(inputType.clone(ArrayRef<int64_t>{}));
 
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
@@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return inputType.clone(resultShape);
+  return mlir::cast<TensorType>(inputType.clone(resultShape));
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
@@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   auto rhsShape = rhsType.getShape();
 
   if (lhsShape.empty() || rhsShape.empty())
-    return lhsType.clone(ArrayRef<int64_t>{});
+    return mlir::cast<TensorType>(lhsType.clone(ArrayRef<int64_t>{}));
 
   if (ShapedType::isDynamicShape(lhsShape) ||
       ShapedType::isDynamicShape(rhsShape))
-    return lhsType.clone({ShapedType::kDynamic});
+    return mlir::cast<TensorType>(lhsType.clone({ShapedType::kDynamic}));
 
   SmallVector<int64_t> intermediateShape;
   unsigned currLhsDim = 0, currRhsDim = 0;
@@ -149,7 +149,7 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
     assert(rhsShape[currRhsDim] == 1);
   }
 
-  return lhsType.clone(intermediateShape);
+  return mlir::cast<TensorType>(lhsType.clone(intermediateShape));
 }
 
 SmallVector<ReassociationExprs>
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..e15f81a1ef433 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -735,8 +735,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
     if (llvm::isa<TensorType>(opResult.getType())) {
       // The OpResult is a tensor. Such values are replaced with memrefs during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
+      assert(llvm::isa<BaseMemRefType>(replacement.getType()) &&
              "tensor op result should be replaced with a memref value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 4ac6eca586961..8f9438babc070 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -77,9 +77,9 @@ struct CastOpInterface
     // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
     // change.
     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
-    return MemRefType::get(
+    return llvm::cast<BaseMemRefType>(MemRefType::get(
         rankedResultType.getShape(), rankedResultType.getElementType(),
-        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -157,8 +157,9 @@ struct CollapseShapeOpInterface
           tensorResultType, srcBufferType.getMemorySpace());
     }
 
-    return memref::CollapseShapeOp::computeCollapsedType(
-        srcBufferType, collapseShapeOp.getReassociationIndices());
+    return llvm::cast<BaseMemRefType>(
+        memref::CollapseShapeOp::computeCollapsedType(
+            srcBufferType, collapseShapeOp.getReassociationIndices()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -325,7 +326,7 @@ struct ExpandShapeOpInterface
         expandShapeOp.getReassociationIndices());
     if (failed(maybeResultType))
       return failure();
-    return *maybeResultType;
+    return llvm::cast<BaseMemRefType>(*maybeResultType);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -405,10 +406,11 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(),
-        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides);
+    return mlir::cast<BaseMemRefType>(
+        memref::SubViewOp::inferRankReducedResultType(
+            extractSliceOp.getType().getShape(),
+            llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+            mixedStrides));
   }
 };
 
@@ -746,9 +748,10 @@ struct PadOpInterface
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
-    return MemRefType::get(padOp.getResultType().getShape(),
-                           padOp.getResultType().getElementType(), layout,
-                           maybeSrcBufferType->getMemorySpace());
+    return llvm::cast<BaseMemRefType>(
+        MemRefType::get(padOp.getResultType().getShape(),
+                        padOp.getResultType().getElementType(), layout,
+                        maybeSrcBufferType->getMemorySpace()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78c242571935c..1fcf9df052ae3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -396,6 +396,13 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
                          getElementType());
 }
 
+TensorDescType TensorDescType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                         Type elementType) const {
+  return TensorDescType::get(shape.value_or(this->getShape()), elementType,
+                             this->getArrayLength(), this->getBoundaryCheck(),
+                             this->getMemorySpace(), this->getSgMap());
+}
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..02e7038a75fff 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/TensorEncoding.h"
@@ -256,45 +257,6 @@ VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 // TensorType
 //===----------------------------------------------------------------------===//
 
-Type TensorType::getElementType() const {
-  return llvm::TypeSwitch<TensorType, Type>(*this)
-      .Case<RankedTensorType, UnrankedTensorType>(
-          [](auto type) { return type.getElementType(); });
-}
-
-bool TensorType::hasRank() const {
-  return !llvm::isa<UnrankedTensorType>(*this);
-}
-
-ArrayRef<int64_t> TensorType::getShape() const {
-  return llvm::cast<RankedTensorType>(*this).getShape();
-}
-
-TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                                 Type elementType) const {
-  if (llvm::dyn_cast<UnrankedTensorType>(*this)) {
-    if (shape)
-      return RankedTensorType::get(*shape, elementType);
-    return UnrankedTensorType::get(elementType);
-  }
-
-  auto rankedTy = llvm::cast<RankedTensorType>(*this);
-  if (!shape)
-    return RankedTensorType::get(rankedTy.getShape(), elementType,
-                                 rankedTy.getEncoding());
-  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
-                               rankedTy.getEncoding());
-}
-
-RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
-                                   Type elementType) const {
-  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
-}
-
-RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
-  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
-}
-
 // Check if "elementType" can be an element type of a tensor.
 static LogicalResult
 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -317,6 +279,12 @@ bool TensorType::isValidElementType(Type type) {
 //===----------------------------------------------------------------------===//
 // RankedTensorType
 //===----------------------------------------------------------------------===//
+RankedTensorType
+RankedTensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                            Type elementType) const {
+  return RankedTensorType::get(shape.value_or(this->getShape()), elementType,
+                               this->getEncoding());
+}
 
 LogicalResult
 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -335,6 +303,13 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
 // UnrankedTensorType
 //===----------------------------------------------------------------------===//
 
+/// Clones this unranked tensor with `elementType`. Without a `shape` the
+/// result stays unranked; with one, it is a RankedTensorType of `*shape`
+/// built without an explicit encoding argument.
+TensorType UnrankedTensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                         Type elementType) const {
+  if (!shape.has_value())
+    return UnrankedTensorType::get(elementType);
+  return RankedTensorType::get(shape.value(), elementType);
+}
+
 LogicalResult
 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                            Type elementType) {
@@ -342,65 +317,18 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 //===----------------------------------------------------------------------===//
-// BaseMemRefType
+// MemRefType
 //===----------------------------------------------------------------------===//
 
-Type BaseMemRefType::getElementType() const {
-  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
-      .Case<MemRefType, UnrankedMemRefType>(
-          [](auto type) { return type.getElementType(); });
-}
-
-bool BaseMemRefType::hasRank() const {
-  return !llvm::isa<UnrankedMemRefType>(*this);
-}
-
-ArrayRef<int64_t> BaseMemRefType::getShape() const {
-  return llvm::cast<MemRefType>(*this).getShape();
-}
-
-BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                                         Type elementType) const {
-  if (llvm::dyn_cast<UnrankedMemRefType>(*this)) {
-    if (!shape)
-      return UnrankedMemRefType::get(elementType, getMemorySpace());
-    MemRefType::Builder builder(*shape, elementType);
-    builder.setMemorySpace(getMemorySpace());
-    return builder;
-  }
-
-  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+/// Returns a copy of this memref type with `elementType` and, if provided,
+/// `*shape`. All other properties come from a `MemRefType::Builder` seeded
+/// with `*this`.
+MemRefType MemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                 Type elementType) const {
+  MemRefType::Builder builder(*this);
   if (shape)
     builder.setShape(*shape);
   builder.setElementType(elementType);
-  return builder;
-}
-
-MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
-                                 Type elementType) const {
-  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
-}
-
-MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
-  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+  return MemRefType(builder);
 }
 
-Attribute BaseMemRefType::getMemorySpace() const {
-  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
-    return rankedMemRefTy.getMemorySpace();
-  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
-}
-
-unsigned BaseMemRefType::getMemorySpaceAsInt() const {
-  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
-    return rankedMemRefTy.getMemorySpaceAsInt();
-  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
-}
-
-//===----------------------------------------------------------------------===//
-// MemRefType
-//===----------------------------------------------------------------------===//
-
 std::optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> reducedShape,
@@ -888,6 +816,17 @@ bool MemRefType::isLastDimUnitStride() {
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
 
+/// Clones this unranked memref with `elementType`, keeping its memory space.
+/// Given a `shape`, the result is a ranked MemRefType of `*shape`; otherwise
+/// it remains an UnrankedMemRefType.
+BaseMemRefType
+UnrankedMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                              Type elementType) const {
+  Attribute memorySpace = getMemorySpace();
+  if (shape) {
+    MemRefType::Builder rankedBuilder(*shape, elementType);
+    rankedBuilder.setMemorySpace(memorySpace);
+    return MemRefType(rankedBuilder);
+  }
+  return UnrankedMemRefType::get(elementType, memorySpace);
+}
+
+/// Returns the memory space as an unsigned integer by delegating to
+/// `detail::getMemorySpaceAsInt`.
 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
   return detail::getMemorySpaceAsInt(getMemorySpace());
 }
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..61fab9c889be7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -403,4 +403,50 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
   let mnemonic = "op_asm_type_interface";
 }
 
+// Test-only, always-ranked tensor type implementing TensorTypeInterface.
+// It is used to check that a non-builtin type implementing the interface is
+// recognized as `mlir::TensorType`.
+def TestTensorType : Test_Type<"TestTensor", [TensorTypeInterface]> {
+  let mnemonic = "test_tensor";
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "mlir::Type":$elementType
+  );
+  let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
+
+  let extraClassDeclaration = [{
+    // ShapedTypeInterface:
+    // Every TestTensorType stores an explicit shape, so it is always ranked.
+    bool hasRank() const {
+      return true;
+    }
+    // Returns a TestTensorType with `elementType`; the current shape is kept
+    // when `shape` is std::nullopt.
+    test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+      return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType);
+    }
+  }];
+}
+
+// Test-only, always-ranked memref type implementing BaseMemRefTypeInterface.
+// `memSpace` is an arbitrary attribute that defaults to null and is printed
+// only when present.
+def TestMemrefType : Test_Type<"TestMemref", [BaseMemRefTypeInterface]> {
+  let mnemonic = "test_memref";
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "mlir::Type":$elementType,
+    DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
+  );
+  let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
+
+  let extraClassDeclaration = [{
+    // ShapedTypeInterface:
+    // Every TestMemrefType stores an explicit shape, so it is always ranked.
+    bool hasRank() const {
+      return true;
+    }
+    // Returns a TestMemrefType with `elementType` and the same memory space;
+    // the current shape is kept when `shape` is std::nullopt.
+    test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+      return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace());
+    }
+
+    // BaseMemRefTypeInterface:
+    // The memory space is the `memSpace` parameter, returned unchanged.
+    mlir::Attribute getMemorySpace() const {
+      return getMemSpace();
+    }
+    // [deprecated] Always returns 0, even when `memSpace` holds an integer
+    // attribute. NOTE(review): generic code reading integer memory spaces will
+    // see 0 for this type; confirm that is acceptable for the tests using it.
+    unsigned getMemorySpaceAsInt() const { return 0; }
+  }];
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7da..004367f7759f2 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OwningOpRef.h"
 #include "gtest/gtest.h"
@@ -84,3 +85,36 @@ TEST(InterfaceTest, TestImplicitConversion) {
   typeA = typeB;
   EXPECT_EQ(typeA, typeB);
 }
+
+/// Checks that a non-builtin type implementing TensorTypeInterface is
+/// recognized as `mlir::TensorType`, and that `clone` through the shaped-type
+/// API applies the requested shape while keeping the element type.
+TEST(InterfaceTest, TestCustomTensorIsTensorType) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  auto customTensorType = test::TestTensorType::get(
+      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+  EXPECT_TRUE(mlir::isa<mlir::TensorType>(customTensorType));
+
+  const int64_t newShape[] = {3, 4, 5};
+  auto customCloneType = customTensorType.clone(newShape);
+  EXPECT_EQ(customTensorType.getElementType(),
+            customCloneType.getElementType());
+  // Without this check, a cloneWith that ignored `shape` would still pass.
+  EXPECT_EQ(customCloneType.getShape(), llvm::ArrayRef<int64_t>(newShape));
+  EXPECT_TRUE(mlir::isa<mlir::TensorType>(customCloneType));
+  EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));
+}
+
+/// Checks that a non-builtin type implementing BaseMemRefTypeInterface is
+/// recognized as `mlir::BaseMemRefType`, and that `clone` applies the new
+/// shape while keeping the element type and memory space.
+TEST(InterfaceTest, TestCustomMemrefIsBaseMemref) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  auto customMemrefType = test::TestMemrefType::get(
+      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
+      mlir::StringAttr::get(&context, "some_memspace"));
+  EXPECT_TRUE(mlir::isa<mlir::BaseMemRefType>(customMemrefType));
+
+  const int64_t newShape[] = {3, 4, 5};
+  auto customCloneType = customMemrefType.clone(newShape);
+  EXPECT_EQ(customMemrefType.getElementType(),
+            customCloneType.getElementType());
+  // Without this check, a cloneWith that ignored `shape` would still pass.
+  EXPECT_EQ(customCloneType.getShape(), llvm::ArrayRef<int64_t>(newShape));
+  EXPECT_TRUE(mlir::isa<mlir::BaseMemRefType>(customCloneType));
+  EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
+  EXPECT_EQ(customMemrefType.getMemorySpace(),
+            mlir::cast<test::TestMemrefType>(customCloneType).getMemorySpace());
+}



More information about the Mlir-commits mailing list