[Mlir-commits] [mlir] eb7f355 - [mlir][NFC] Minor cleanups around ShapedType
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 18 19:37:43 PDT 2023
Author: Matthias Springer
Date: 2023-04-19T11:30:45+09:00
New Revision: eb7f355725d1cd6d875446581def4d742f179838
URL: https://github.com/llvm/llvm-project/commit/eb7f355725d1cd6d875446581def4d742f179838
DIFF: https://github.com/llvm/llvm-project/commit/eb7f355725d1cd6d875446581def4d742f179838.diff
LOG: [mlir][NFC] Minor cleanups around ShapedType
* Remove unnecessary casts.
* Use concrete shaped types (e.g., `MemRefType`, `RankedTensorType`) instead of `ShapedType` when possible.
* Minor documentation cleanups.
Differential Revision: https://reviews.llvm.org/D148488
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6a7f542887fd8..265e400045d8b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1390,18 +1390,16 @@ def MemRef_ReinterpretCastOp
let extraClassDeclaration = extraBaseClassDeclaration # [{
// The result of the op is always a ranked memref.
- MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+ MemRefType getType() { return getResult().getType(); }
Value getViewSource() { return getSource(); }
- /// Return the rank of the source ShapedType.
- unsigned getResultRank() {
- return getResult().getType().cast<ShapedType>().getRank();
- }
+ /// Return the rank of the result type.
+ unsigned getResultRank() { return getType().getRank(); }
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
+ unsigned resultRank = getType().getRank();
return {1, resultRank, resultRank};
}
@@ -1830,8 +1828,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
The representation based on offsets, sizes and strides support a
partially-static specification via attributes specified through the
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
- sentinel value ShapedType::kDynamic and
- ShapedType::kDynamic encodes that the corresponding entry has
+ sentinel value ShapedType::kDynamic encodes that the corresponding entry has
a dynamic value.
A subview operation may additionally reduce the rank of the resulting view
@@ -2122,7 +2119,6 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
let extraClassDeclaration = [{
static StringRef getPermutationAttrStrName() { return "permutation"; }
- ShapedType getShapedType() { return getIn().getType().cast<ShapedType>(); }
}];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8d0028d6d5343..0c589ce920fe0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -226,7 +226,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
Pure,
TypesMatchWith<"result type matches element type of tensor",
"tensor", "result",
- "$_self.cast<ShapedType>().getElementType()">]> {
+ "$_self.cast<TensorType>().getElementType()">]> {
let summary = "element extraction operation";
let description = [{
The `tensor.extract` op reads a ranked tensor and returns one element as
@@ -281,8 +281,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
The representation based on offsets, sizes and strides support a
partially-static specification via attributes specified through the
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
- sentinel value ShapedType::kDynamic and
- ShapedType::kDynamic encodes that the corresponding entry has
+ sentinel value ShapedType::kDynamic encodes that the corresponding entry has
a dynamic value.
After buffer allocation, the "extract_slice" op is expected to lower into a
@@ -389,12 +388,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
- ShapedType sourceShapedTensorType,
+ RankedTensorType sourceTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static RankedTensorType inferResultType(
- ShapedType sourceShapedTensorType,
+ RankedTensorType sourceTensorType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
@@ -459,8 +458,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type, 2>("
- "$_self.cast<ShapedType>().getNumElements(), "
- "$_self.cast<ShapedType>().getElementType())">
+ "$_self.cast<RankedTensorType>().getNumElements(), "
+ "$_self.cast<RankedTensorType>().getElementType())">
]> {
let summary = "tensor from elements operation.";
let description = [{
@@ -695,7 +694,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
"$_self">,
TypesMatchWith<"scalar type matches element type of dest",
"dest", "scalar",
- "$_self.cast<ShapedType>().getElementType()">]> {
+ "$_self.cast<TensorType>().getElementType()">]> {
let summary = "element insertion operation";
let description = [{
The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
@@ -770,8 +769,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
The representation based on offsets, sizes and strides support a
partially-static specification via attributes specified through the
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
- sentinel value ShapedType::kDynamic and
- ShapedType::kDynamic encodes that the corresponding entry has
+ sentinel value ShapedType::kDynamic encodes that the corresponding entry has
a dynamic value.
After buffer allocation, the "insert_slice" op is expected to lower into a
@@ -1381,8 +1379,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
The representation based on offsets, sizes and strides support a
partially-static specification via attributes specified through the
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
- sentinel value ShapedType::kDynamic and
- ShapedType::kDynamic encodes that the corresponding entry has
+ sentinel value ShapedType::kDynamic encodes that the corresponding entry has
a dynamic value.
After buffer allocation, the "parallel_insert_slice" op is expected to lower
@@ -1790,10 +1787,10 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
- // Method to get the `ShapedType` of the result based on the inner tiles,
- // position of the inner tiles (innerDimsPos) and interchange vector of
- // outer loops (outerDimsPerm).
- static ShapedType inferPackedType(ShapedType sourceType,
+ // Method to get the `RankedTensorType` of the result based on the inner
+ // tiles, position of the inner tiles (innerDimsPos) and interchange vector
+ // of outer loops (outerDimsPerm).
+ static RankedTensorType inferPackedType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2a95ff243fbf9..f4508b9b3ddfd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -529,7 +529,7 @@ def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
- "$_self.cast<ShapedType>().getElementType()">]>,
+ "$_self.cast<VectorType>().getElementType()">]>,
Arguments<(ins AnyVectorOfAnyRank:$vector,
Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyType:$result)> {
@@ -644,7 +644,7 @@ def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
TypesMatchWith<"source operand type matches element type of result",
"result", "source",
- "$_self.cast<ShapedType>().getElementType()">,
+ "$_self.cast<VectorType>().getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
Optional<AnySignlessIntegerOrIndex>:$position)>,
@@ -1884,23 +1884,15 @@ def Vector_GatherOp :
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
+
let extraClassDeclaration = [{
- ShapedType getBaseType() {
- return getBase().getType().cast<ShapedType>();
- }
- VectorType getIndexVectorType() {
- return getIndexVec().getType().cast<VectorType>();
- }
- VectorType getMaskVectorType() {
- return getMask().getType().cast<VectorType>();
- }
- VectorType getPassThruVectorType() {
- return getPassThru().getType().cast<VectorType>();
- }
- VectorType getVectorType() {
- return getResult().getType().cast<VectorType>();
- }
+ ShapedType getBaseType() { return getBase().getType(); }
+ VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getMaskVectorType() { return getMask().getType(); }
+ VectorType getPassThruVectorType() { return getPassThru().getType(); }
+ VectorType getVectorType() { return getResult().getType(); }
}];
+
let assemblyFormat =
"$base `[` $indices `]` `[` $index_vec `]` `,` "
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
@@ -1960,20 +1952,14 @@ def Vector_ScatterOp :
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```
}];
+
let extraClassDeclaration = [{
- MemRefType getMemRefType() {
- return getBase().getType().cast<MemRefType>();
- }
- VectorType getIndexVectorType() {
- return getIndexVec().getType().cast<VectorType>();
- }
- VectorType getMaskVectorType() {
- return getMask().getType().cast<VectorType>();
- }
- VectorType getVectorType() {
- return getValueToStore().getType().cast<VectorType>();
- }
+ MemRefType getMemRefType() { return getBase().getType(); }
+ VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getMaskVectorType() { return getMask().getType(); }
+ VectorType getVectorType() { return getValueToStore().getType(); }
}];
+
let assemblyFormat =
"$base `[` $indices `]` `[` $index_vec `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index b5870af8c7936..4cddcc1764690 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -50,9 +50,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
`getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
3. if an entry of `static_offsets` (resp. `static_sizes`,
`static_strides`) is equal to a special sentinel value, namely
- `ShapedType::kDynamic` (resp. `ShapedType::kDynamic`,
- `ShapedType::kDynamic`), then the corresponding entry is
- a dynamic offset (resp. size, stride).
+ `ShapedType::kDynamic`, then the corresponding entry is a dynamic
+ offset (resp. size, stride).
4. a variadic `offset` (resp. `sizes`, `strides`) operand must be present
for each dynamic offset (resp. size, stride).
5. `offsets`, `sizes` and `strides` operands are specified in this order
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index e29c5d2565770..37aa6cf53c206 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -128,7 +128,7 @@ bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
const DataLayout *defaultLayout) const {
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (ShapedType::isDynamic(type.getDimSize(i)))
+ if (type.isDynamicDim(i))
continue;
sizeDivisor = sizeDivisor * type.getDimSize(i);
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2af5a2522566d..07d406bbb1c1a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -154,10 +154,10 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
auto computeNumElements =
[&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
// Compute number of elements.
- int64_t size = type.getShape()[0];
- Value numElements = ((size == ShapedType::kDynamic)
- ? getDynamicSize()
- : createIndexConstant(rewriter, loc, size));
+ Value numElements =
+ type.isDynamicDim(0)
+ ? getDynamicSize()
+ : createIndexConstant(rewriter, loc, type.getDimSize(0));
Type indexType = getIndexType();
if (numElements.getType() != indexType)
numElements = typeConverter->materializeTargetConversion(
@@ -987,7 +987,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
// First make sure we have an unranked memref descriptor representation.
- auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
+ auto makeUnranked = [&, this](Value ranked, MemRefType type) {
auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
type.getRank());
auto *typeConverter = getTypeConverter();
@@ -1011,12 +1011,14 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto stackSaveOp =
rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
- Value unrankedSource = srcType.hasRank()
- ? makeUnranked(adaptor.getSource(), srcType)
- : adaptor.getSource();
- Value unrankedTarget = targetType.hasRank()
- ? makeUnranked(adaptor.getTarget(), targetType)
- : adaptor.getTarget();
+ auto srcMemRefType = srcType.dyn_cast<MemRefType>();
+ Value unrankedSource =
+ srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
+ : adaptor.getSource();
+ auto targetMemRefType = targetType.dyn_cast<MemRefType>();
+ Value unrankedTarget =
+ targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
+ : adaptor.getTarget();
// Now promote the unranked descriptors to the stack.
auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
@@ -1390,11 +1392,11 @@ struct MemRefReshapeOpLowering
}
Value dimSize;
- int64_t size = targetMemRefType.getDimSize(i);
// If the size of this dimension is dynamic, then load it at runtime
// from the shape operand.
- if (!ShapedType::isDynamic(size)) {
- dimSize = createIndexConstant(rewriter, loc, size);
+ if (!targetMemRefType.isDynamicDim(i)) {
+ dimSize = createIndexConstant(rewriter, loc,
+ targetMemRefType.getDimSize(i));
} else {
Value shapeOp = reshapeOp.getShape();
Value index = createIndexConstant(rewriter, loc, i);
@@ -1589,7 +1591,8 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
auto targetMemRef = MemRefDescriptor::undef(
- rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
+ rewriter, loc,
+ typeConverter->convertType(transposeOp.getIn().getType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index da582a50da307..53b96803290f7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1178,7 +1178,7 @@ LogicalResult MapOp::verify() {
}
// The shape of each input must match the shape of the output.
- auto outputShape = getInit().getType().cast<ShapedType>().getShape();
+ auto outputShape = getInit().getType().getShape();
for (Type inputArgType : TypeRange{getInputs()}) {
auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
if (inputElemShape != outputShape) {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5960b9f0c5c69..ee47547a1775b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3236,7 +3236,7 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult TransposeOp::verify() {
if (!getPermutation().isPermutation())
return emitOpError("expected a permutation map");
- if (getPermutation().getNumDims() != getShapedType().getRank())
+ if (getPermutation().getNumDims() != getIn().getType().getRank())
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = getIn().getType().cast<MemRefType>();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e70a2794701f8..f15aea61b82fc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1617,27 +1617,26 @@ void ExtractSliceOp::getAsmResultNames(
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
- ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
+ RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
assert(static_cast<int64_t>(staticSizes.size()) ==
- sourceShapedTensorType.getRank() &&
+ sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
- return RankedTensorType::get(staticSizes,
- sourceShapedTensorType.getElementType());
+ return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
}
RankedTensorType ExtractSliceOp::inferResultType(
- ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
+ RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
- return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
+ return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
staticSizes, staticStrides);
}
@@ -1756,22 +1755,21 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
-template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
- OpTy op, Type expectedType) {
- auto memrefType = expectedType.cast<ShapedType>();
+ Operation *op,
+ RankedTensorType expectedType) {
switch (result) {
case SliceVerificationResult::Success:
return success();
case SliceVerificationResult::RankTooLarge:
- return op.emitError("expected rank to be smaller or equal to ")
+ return op->emitError("expected rank to be smaller or equal to ")
<< "the other rank. ";
case SliceVerificationResult::SizeMismatch:
- return op.emitError("expected type to be ")
+ return op->emitError("expected type to be ")
<< expectedType << " or a rank-reduced version. (size mismatch) ";
case SliceVerificationResult::ElemTypeMismatch:
- return op.emitError("expected element type to be ")
- << memrefType.getElementType();
+ return op->emitError("expected element type to be ")
+ << expectedType.getElementType();
default:
llvm_unreachable("unexpected extract_slice op verification result");
}
@@ -2147,9 +2145,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
static SliceVerificationResult verifyInsertSliceOp(
- ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
- ShapedType *expectedType = nullptr) {
+ RankedTensorType srcType, RankedTensorType dstType,
+ ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
RankedTensorType expected = ExtractSliceOp::inferResultType(
@@ -2161,7 +2159,7 @@ static SliceVerificationResult verifyInsertSliceOp(
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
- ShapedType expectedType;
+ RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
@@ -2334,8 +2332,10 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
auto src =
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
- auto srcType = src.getType().template cast<ShapedType>();
- auto dstType = dst.getType().template cast<ShapedType>();
+ auto srcType = src.getType().template dyn_cast<RankedTensorType>();
+ auto dstType = dst.getType().template dyn_cast<RankedTensorType>();
+ if (!srcType || !dstType)
+ return failure();
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
insertSliceOp.getStaticSizes(),
insertSliceOp.getStaticStrides()) !=
@@ -3072,7 +3072,7 @@ LogicalResult ParallelInsertSliceOp::verify() {
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
- ShapedType expectedType;
+ RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
@@ -3307,9 +3307,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -3344,7 +3344,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
- ShapedType expectedPackedType = PackOp::inferPackedType(
+ RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
return op->emitError("the shape of output is not large enough to hold the "
@@ -3594,10 +3594,10 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
- ArrayRef<int64_t> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+ ArrayRef<int64_t> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return RankedTensorType::get(resultShape, sourceType.getElementType());
More information about the Mlir-commits mailing list