[Mlir-commits] [mlir] c0261eb - [mlir][IR] Improve `clone` function return type of shaped types
Matthias Springer
llvmlistbot at llvm.org
Thu May 25 00:27:43 PDT 2023
Author: Matthias Springer
Date: 2023-05-25T09:27:33+02:00
New Revision: c0261eb02bb092d8842d3c57297e23f31d75cb96
URL: https://github.com/llvm/llvm-project/commit/c0261eb02bb092d8842d3c57297e23f31d75cb96
DIFF: https://github.com/llvm/llvm-project/commit/c0261eb02bb092d8842d3c57297e23f31d75cb96.diff
LOG: [mlir][IR] Improve `clone` function return type of shaped types
There are `clone` overloads that take a shape as a parameter. These overloads are guaranteed to return a ranked shaped type.
`TensorType::clone`/`BaseMemRefType::clone` used to always return a `TensorType`/`BaseMemRefType`. The variants that take a shape parameter now return a `RankedTensorType`/`MemRefType`. Better static type information can make extra casts at the call site obsolete.
E.g.:
```
{TensorType/RankedTensorType} t;
t.clone({1, 2}); // now returns RankedTensorType instead of TensorType
```
Also improve documentation for `clone`.
Differential Revision: https://reviews.llvm.org/D150865
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/IR/BuiltinTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index bb38985715c09..db38e2e1bce22 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -59,8 +59,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
let methods = [
InterfaceMethod<[{
- Returns a clone of this type with the given shape and element
- type. If a shape is not provided, the current shape of the type is used.
+ Returns a clone of this type with the given shape and element type.
+
+ If no shape is provided, the shape of this type is used. In that case, if
+ this type is unranked, so is the resulting type.
+
+ If a shape is provided, the resulting type is always ranked, even if this
+ type is unranked.
}],
"::mlir::ShapedType", "cloneWith", (ins
"::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
@@ -89,7 +94,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dValue) {
- return dValue == kDynamic;
+ return dValue == kDynamic;
}
/// Whether the given shape has any size that indicates a dynamic dimension.
@@ -99,18 +104,24 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return $_type.cloneWith(shape, elementType);
+ return cloneWith(shape, elementType);
}
- /// Return a clone of this type with the given new shape.
+
+ /// Return a clone of this type with the given new shape and this type's
+ /// element type. The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return $_type.cloneWith(shape, $_type.getElementType());
+ return cloneWith(shape, getElementType());
}
- /// Return a clone of this type with the given new element type.
+ }];
+
+ let extraSharedClassDeclaration = [{
+ /// Return a clone of this type with the given new element type. The
+ /// returned type is ranked if and only if this type is ranked. In that
+ /// case, the returned type has the same shape as this type.
auto clone(::mlir::Type elementType) {
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 4fc82dd7a8e9d..79313b6facda0 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -27,6 +27,8 @@ class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
+class MemRefType;
+class RankedTensorType;
class StringAttr;
class TypeRange;
@@ -95,6 +97,17 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ // Re-expose the inherited clone overloads hidden by the declarations below.
+ using ShapedType::Trait<TensorType>::clone;
+
+ /// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
+ RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+ /// Return a clone of this type with the given new shape and this type's
+ /// element type. The returned type is ranked, even if this is unranked.
+ RankedTensorType clone(ArrayRef<int64_t> shape) const;
+
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type);
@@ -131,6 +144,17 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ // Re-expose the inherited clone overloads hidden by the declarations below.
+ using ShapedType::Trait<BaseMemRefType>::clone;
+
+ /// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
+ MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+ /// Return a clone of this type with the given new shape and this type's
+ /// element type. The returned type is ranked, even if this is unranked.
+ MemRefType clone(ArrayRef<int64_t> shape) const;
+
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type);
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 218c240743ae6..58a0156d54a1f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -629,7 +629,14 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<MemRefType>::clone;
+ using BaseMemRefType::clone;
+
+ /// Return a clone of this type with the given new element type and the same
+ /// shape as this type. Mirrors RankedTensorType::clone(Type).
+ MemRefType clone(::mlir::Type elementType) {
+ return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
+ }
+
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -794,7 +801,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<RankedTensorType>::clone;
+ using TensorType::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -807,6 +814,12 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+
+ /// Return a clone of this type with the given new element type and the same
+ /// shape as this type.
+ RankedTensorType clone(::mlir::Type elementType) {
+ return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
@@ -931,7 +944,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<UnrankedMemRefType>::clone;
+ using BaseMemRefType::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -984,7 +997,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<UnrankedTensorType>::clone;
+ using TensorType::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index b46ea8a2e6e10..c816e4a6dbcf3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -291,6 +291,17 @@ TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
rankedTy.getEncoding());
}
+// `cloneWith` always produces a ranked type when a shape is provided, so the
+// cast to `RankedTensorType` cannot fail.
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
+ Type elementType) const {
+ return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
+}
+
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
+ return clone(shape, getElementType());
+}
+
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -370,6 +381,17 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
return builder;
}
+// `cloneWith` always produces a ranked type when a shape is provided, so the
+// cast to `MemRefType` cannot fail.
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
+ Type elementType) const {
+ return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
+}
+
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
+ return clone(shape, getElementType());
+}
+
Attribute BaseMemRefType::getMemorySpace() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
return rankedMemRefTy.getMemorySpace();
More information about the Mlir-commits
mailing list