[Mlir-commits] [mlir] c0261eb - [mlir][IR] Improve `clone` function return type of shaped types

Matthias Springer llvmlistbot at llvm.org
Thu May 25 00:27:43 PDT 2023


Author: Matthias Springer
Date: 2023-05-25T09:27:33+02:00
New Revision: c0261eb02bb092d8842d3c57297e23f31d75cb96

URL: https://github.com/llvm/llvm-project/commit/c0261eb02bb092d8842d3c57297e23f31d75cb96
DIFF: https://github.com/llvm/llvm-project/commit/c0261eb02bb092d8842d3c57297e23f31d75cb96.diff

LOG: [mlir][IR] Improve `clone` function return type of shaped types

There are `clone` overloads that take a shape as a parameter. These overloads are guaranteed to return a ranked shaped type, even if the type being cloned is unranked.

`TensorType::clone`/`BaseMemRefType::clone` used to always return a `TensorType`/`BaseMemRefType`, respectively. The variants that take a shape parameter now return a `RankedTensorType`/`MemRefType`, respectively. This more precise static type information makes extra casts at call sites unnecessary.

E.g.:
```
{TensorType/RankedTensorType} t;
t.clone({1, 2})  // now returns RankedTensorType instead of TensorType
```

Also improve the documentation of the `clone` overloads.

Differential Revision: https://reviews.llvm.org/D150865

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypeInterfaces.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/IR/BuiltinTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index bb38985715c09..db38e2e1bce22 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -59,8 +59,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
   let methods = [
     InterfaceMethod<[{
-      Returns a clone of this type with the given shape and element
-      type. If a shape is not provided, the current shape of the type is used.
+      Returns a clone of this type with the given shape and element type.
+
+      If no shape is provided, the shape of this type is used. In that case, if
+      this type is unranked, so is the resulting type.
+
+      If a shape is provided, the resulting type is always ranked, even if this
+      type is unranked.
     }],
     "::mlir::ShapedType", "cloneWith", (ins
       "::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
@@ -89,7 +94,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Whether the given dimension size indicates a dynamic dimension.
     static constexpr bool isDynamic(int64_t dValue) {
-	return dValue == kDynamic;
+      return dValue == kDynamic;
     }
 
     /// Whether the given shape has any size that indicates a dynamic dimension.
@@ -99,18 +104,24 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
+    /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return $_type.cloneWith(shape, elementType);
+      return cloneWith(shape, elementType);
     }
-    /// Return a clone of this type with the given new shape.
+
+    /// Return a clone of this type with the given new shape. The returned type
+    /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return $_type.cloneWith(shape, $_type.getElementType());
+      return cloneWith(shape, getElementType());
     }
-    /// Return a clone of this type with the given new element type.
+  }];
+
+  let extraSharedClassDeclaration = [{
+    /// Return a clone of this type with the given new element type. The
+    /// returned type is ranked if and only if this type is ranked. In that
+    /// case, the returned type has the same shape as this type.
     auto clone(::mlir::Type elementType) {
       return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
     }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 4fc82dd7a8e9d..79313b6facda0 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -27,6 +27,8 @@ class AffineMap;
 class FloatType;
 class IndexType;
 class IntegerType;
+class MemRefType;
+class RankedTensorType;
 class StringAttr;
 class TypeRange;
 
@@ -95,6 +97,17 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
   TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                        Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<TensorType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a tensor.
   static bool isValidElementType(Type type);
 
@@ -131,6 +144,17 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<BaseMemRefType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a memref.
   static bool isValidElementType(Type type);
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 218c240743ae6..58a0156d54a1f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -629,7 +629,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<MemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -794,7 +794,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<RankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -807,6 +807,12 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    RankedTensorType clone(::mlir::Type elementType) {
+      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -931,7 +937,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedMemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -946,6 +952,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
     /// [deprecated] Returns the memory space in old raw integer representation.
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    MemRefType clone(::mlir::Type elementType) {
+      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -984,7 +996,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index b46ea8a2e6e10..c816e4a6dbcf3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -291,6 +291,15 @@ TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                rankedTy.getEncoding());
 }
 
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
+                                   Type elementType) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
+}
+
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
+}
+
 // Check if "elementType" can be an element type of a tensor.
 static LogicalResult
 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -370,6 +379,15 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
   return builder;
 }
 
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
+                                 Type elementType) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
+}
+
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+}
+
 Attribute BaseMemRefType::getMemorySpace() const {
   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
     return rankedMemRefTy.getMemorySpace();


        


More information about the Mlir-commits mailing list