[Mlir-commits] [mlir] eb7f355 - [mlir][NFC] Minor cleanups around ShapedType

Matthias Springer llvmlistbot at llvm.org
Tue Apr 18 19:37:43 PDT 2023


Author: Matthias Springer
Date: 2023-04-19T11:30:45+09:00
New Revision: eb7f355725d1cd6d875446581def4d742f179838

URL: https://github.com/llvm/llvm-project/commit/eb7f355725d1cd6d875446581def4d742f179838
DIFF: https://github.com/llvm/llvm-project/commit/eb7f355725d1cd6d875446581def4d742f179838.diff

LOG: [mlir][NFC] Minor cleanups around ShapedType

* Remove unnecessary casts.
* Use concrete shaped types (e.g., `MemRefType`, `RankedTensorType`) instead of `ShapedType` when possible.
* Minor documentation cleanups.

Differential Revision: https://reviews.llvm.org/D148488
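
The recurring pattern behind the cast removals: when an op's ODS
declaration already constrains an operand or result to a concrete type,
the generated accessor returns that C++ type directly, so an extra
`cast<>` is a no-op. A minimal sketch of the before/after, mirroring the
ReinterpretCastOp change below:

    // Before: re-casting a type that the ODS constraint already guarantees.
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }

    // After: the generated accessor is already typed as MemRefType.
    MemRefType getType() { return getResult().getType(); }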

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6a7f542887fd8..265e400045d8b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1390,18 +1390,16 @@ def MemRef_ReinterpretCastOp
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     // The result of the op is always a ranked memref.
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    MemRefType getType() { return getResult().getType(); }
     Value getViewSource() { return getSource(); }
 
-    /// Return the rank of the source ShapedType.
-    unsigned getResultRank() {
-      return getResult().getType().cast<ShapedType>().getRank();
-    }
+    /// Return the rank of the result type.
+    unsigned getResultRank() { return getType().getRank(); }
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
+      unsigned resultRank = getType().getRank();
       return {1, resultRank, resultRank};
     }
 
@@ -1830,8 +1828,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     A subview operation may additionally reduce the rank of the resulting view
@@ -2122,7 +2119,6 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrStrName() { return "permutation"; }
-    ShapedType getShapedType() { return getIn().getType().cast<ShapedType>(); }
   }];
 
   let hasCustomAssemblyFormat = 1;

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8d0028d6d5343..0c589ce920fe0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -226,7 +226,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
     Pure,
     TypesMatchWith<"result type matches element type of tensor",
                    "tensor", "result",
-                   "$_self.cast<ShapedType>().getElementType()">]> {
+                   "$_self.cast<TensorType>().getElementType()">]> {
   let summary = "element extraction operation";
   let description = [{
     The `tensor.extract` op reads a ranked tensor and returns one element as
@@ -281,8 +281,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "extract_slice" op is expected to lower into a
@@ -389,12 +388,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     /// rank-reduced, from the source type and the static representation of
     /// offsets, sizes and strides. Special sentinels encode the dynamic case.
     static RankedTensorType inferResultType(
-      ShapedType sourceShapedTensorType,
+      RankedTensorType sourceTensorType,
       ArrayRef<int64_t> staticOffsets,
       ArrayRef<int64_t> staticSizes,
       ArrayRef<int64_t> staticStrides);
     static RankedTensorType inferResultType(
-      ShapedType sourceShapedTensorType,
+      RankedTensorType sourceTensorType,
       ArrayRef<OpFoldResult> staticOffsets,
       ArrayRef<OpFoldResult> staticSizes,
       ArrayRef<OpFoldResult> staticStrides);
@@ -459,8 +458,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
     Pure,
     TypesMatchWith<"operand types match result element type",
                    "result", "elements", "SmallVector<Type, 2>("
-                   "$_self.cast<ShapedType>().getNumElements(), "
-                   "$_self.cast<ShapedType>().getElementType())">
+                   "$_self.cast<RankedTensorType>().getNumElements(), "
+                   "$_self.cast<RankedTensorType>().getElementType())">
   ]> {
   let summary = "tensor from elements operation.";
   let description = [{
@@ -695,7 +694,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
                    "$_self">,
     TypesMatchWith<"scalar type matches element type of dest",
                    "dest", "scalar",
-                   "$_self.cast<ShapedType>().getElementType()">]> {
+                   "$_self.cast<TensorType>().getElementType()">]> {
   let summary = "element insertion operation";
   let description = [{
     The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
@@ -770,8 +769,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "insert_slice" op is expected to lower into a
@@ -1381,8 +1379,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "parallel_insert_slice" op is expected to lower
@@ -1790,10 +1787,10 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
         ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
-    // Method to get the `ShapedType` of the result based on the inner tiles,
-    // position of the inner tiles (innerDimsPos)  and interchange vector of
-    // outer loops (outerDimsPerm).
-    static ShapedType inferPackedType(ShapedType sourceType,
+    // Method to get the `RankedTensorType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
+    // of outer loops (outerDimsPerm).
+    static RankedTensorType inferPackedType(RankedTensorType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2a95ff243fbf9..f4508b9b3ddfd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -529,7 +529,7 @@ def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",
                     "vector", "result",
-                    "$_self.cast<ShapedType>().getElementType()">]>,
+                    "$_self.cast<VectorType>().getElementType()">]>,
     Arguments<(ins AnyVectorOfAnyRank:$vector,
                    Optional<AnySignlessIntegerOrIndex>:$position)>,
     Results<(outs AnyType:$result)> {
@@ -644,7 +644,7 @@ def Vector_InsertElementOp :
   Vector_Op<"insertelement", [Pure,
      TypesMatchWith<"source operand type matches element type of result",
                     "result", "source",
-                    "$_self.cast<ShapedType>().getElementType()">,
+                    "$_self.cast<VectorType>().getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
      Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
                     Optional<AnySignlessIntegerOrIndex>:$position)>,
@@ -1884,23 +1884,15 @@ def Vector_GatherOp :
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
+
   let extraClassDeclaration = [{
-    ShapedType getBaseType() {
-      return getBase().getType().cast<ShapedType>();
-    }
-    VectorType getIndexVectorType() {
-      return getIndexVec().getType().cast<VectorType>();
-    }
-    VectorType getMaskVectorType() {
-      return getMask().getType().cast<VectorType>();
-    }
-    VectorType getPassThruVectorType() {
-      return getPassThru().getType().cast<VectorType>();
-    }
-    VectorType getVectorType() {
-      return getResult().getType().cast<VectorType>();
-    }
+    ShapedType getBaseType() { return getBase().getType(); }
+    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getMaskVectorType() { return getMask().getType(); }
+    VectorType getPassThruVectorType() { return getPassThru().getType(); }
+    VectorType getVectorType() { return getResult().getType(); }
   }];
+
   let assemblyFormat =
     "$base `[` $indices `]` `[` $index_vec `]` `,` "
     "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
@@ -1960,20 +1952,14 @@ def Vector_ScatterOp :
         : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
     ```
   }];
+
   let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return getBase().getType().cast<MemRefType>();
-    }
-    VectorType getIndexVectorType() {
-      return getIndexVec().getType().cast<VectorType>();
-    }
-    VectorType getMaskVectorType() {
-      return getMask().getType().cast<VectorType>();
-    }
-    VectorType getVectorType() {
-      return getValueToStore().getType().cast<VectorType>();
-    }
+    MemRefType getMemRefType() { return getBase().getType(); }
+    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getMaskVectorType() { return getMask().getType(); }
+    VectorType getVectorType() { return getValueToStore().getType(); }
   }];
+
   let assemblyFormat =
       "$base `[` $indices `]` `[` $index_vec `]` `,` "
       "$mask `,` $valueToStore attr-dict `:` type($base) `,` "

diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index b5870af8c7936..4cddcc1764690 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -50,9 +50,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
          `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
       3. if an entry of `static_offsets` (resp. `static_sizes`,
          `static_strides`) is equal to a special sentinel value, namely
-         `ShapedType::kDynamic` (resp. `ShapedType::kDynamic`,
-         `ShapedType::kDynamic`), then the corresponding entry is
-         a dynamic offset (resp. size, stride).
+         `ShapedType::kDynamic`, then the corresponding entry is a dynamic
+         offset (resp. size, stride).
       4. a variadic `offset` (resp. `sizes`, `strides`) operand  must be present
          for each dynamic offset (resp. size, stride).
       5. `offsets`, `sizes` and `strides` operands are specified in this order
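
As a concrete illustration of the sentinel convention (hypothetical
values, not from the patch): `tensor.extract_slice` type inference takes
the static arrays directly, with `ShapedType::kDynamic` marking entries
that are supplied by SSA operands. Assuming a builder `b`:

    auto src = RankedTensorType::get({8, 16}, b.getF32Type());
    // Static size 4 in dim 0, dynamic size in dim 1.
    RankedTensorType res = tensor::ExtractSliceOp::inferResultType(
        src, /*staticOffsets=*/{0, 0},
        /*staticSizes=*/{4, ShapedType::kDynamic},
        /*staticStrides=*/{1, 1});
    // res is tensor<4x?xf32>.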

diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index e29c5d2565770..37aa6cf53c206 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -128,7 +128,7 @@ bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
     const DataLayout *defaultLayout) const {
   uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
   for (unsigned i = 0, e = type.getRank(); i < e; i++) {
-    if (ShapedType::isDynamic(type.getDimSize(i)))
+    if (type.isDynamicDim(i))
       continue;
     sizeDivisor = sizeDivisor * type.getDimSize(i);
   }
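
The `isDynamicDim(i)` form used here is shorthand for the explicit
sentinel check; for any ranked shaped type the two queries are
equivalent, the helper just states the intent directly:

    // Equivalent for a ranked type such as MemRefType:
    bool viaSentinel = ShapedType::isDynamic(type.getDimSize(i));
    bool viaHelper = type.isDynamicDim(i);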

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2af5a2522566d..07d406bbb1c1a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -154,10 +154,10 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     auto computeNumElements =
         [&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
       // Compute number of elements.
-      int64_t size = type.getShape()[0];
-      Value numElements = ((size == ShapedType::kDynamic)
-                               ? getDynamicSize()
-                               : createIndexConstant(rewriter, loc, size));
+      Value numElements =
+          type.isDynamicDim(0)
+              ? getDynamicSize()
+              : createIndexConstant(rewriter, loc, type.getDimSize(0));
       Type indexType = getIndexType();
       if (numElements.getType() != indexType)
         numElements = typeConverter->materializeTargetConversion(
@@ -987,7 +987,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
 
     // First make sure we have an unranked memref descriptor representation.
-    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
+    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
       auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                     type.getRank());
       auto *typeConverter = getTypeConverter();
@@ -1011,12 +1011,14 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto stackSaveOp =
         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
 
-    Value unrankedSource = srcType.hasRank()
-                               ? makeUnranked(adaptor.getSource(), srcType)
-                               : adaptor.getSource();
-    Value unrankedTarget = targetType.hasRank()
-                               ? makeUnranked(adaptor.getTarget(), targetType)
-                               : adaptor.getTarget();
+    auto srcMemRefType = srcType.dyn_cast<MemRefType>();
+    Value unrankedSource =
+        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
+                      : adaptor.getSource();
+    auto targetMemRefType = targetType.dyn_cast<MemRefType>();
+    Value unrankedTarget =
+        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
+                         : adaptor.getTarget();
 
     // Now promote the unranked descriptors to the stack.
     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
@@ -1390,11 +1392,11 @@ struct MemRefReshapeOpLowering
         }
 
         Value dimSize;
-        int64_t size = targetMemRefType.getDimSize(i);
         // If the size of this dimension is dynamic, then load it at runtime
         // from the shape operand.
-        if (!ShapedType::isDynamic(size)) {
-          dimSize = createIndexConstant(rewriter, loc, size);
+        if (!targetMemRefType.isDynamicDim(i)) {
+          dimSize = createIndexConstant(rewriter, loc,
+                                        targetMemRefType.getDimSize(i));
         } else {
           Value shapeOp = reshapeOp.getShape();
           Value index = createIndexConstant(rewriter, loc, i);
@@ -1589,7 +1591,8 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
     auto targetMemRef = MemRefDescriptor::undef(
-        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
+        rewriter, loc,
+        typeConverter->convertType(transposeOp.getIn().getType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index da582a50da307..53b96803290f7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1178,7 +1178,7 @@ LogicalResult MapOp::verify() {
   }
 
   // The shape of each input must match the shape of the output.
-  auto outputShape = getInit().getType().cast<ShapedType>().getShape();
+  auto outputShape = getInit().getType().getShape();
   for (Type inputArgType : TypeRange{getInputs()}) {
     auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
     if (inputElemShape != outputShape) {

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5960b9f0c5c69..ee47547a1775b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3236,7 +3236,7 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
 LogicalResult TransposeOp::verify() {
   if (!getPermutation().isPermutation())
     return emitOpError("expected a permutation map");
-  if (getPermutation().getNumDims() != getShapedType().getRank())
+  if (getPermutation().getNumDims() != getIn().getType().getRank())
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = getIn().getType().cast<MemRefType>();

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e70a2794701f8..f15aea61b82fc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1617,27 +1617,26 @@ void ExtractSliceOp::getAsmResultNames(
 /// rank-reduced, from the source type and the static representation of
 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
 RankedTensorType ExtractSliceOp::inferResultType(
-    ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
+    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type
   // and strides=1.
   assert(static_cast<int64_t>(staticSizes.size()) ==
-             sourceShapedTensorType.getRank() &&
+             sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes,
-                               sourceShapedTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
-    ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
+    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
+  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                          staticSizes, staticStrides);
 }
 
@@ -1756,22 +1755,21 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
 }
 
-template <typename OpTy>
 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
-                                          OpTy op, Type expectedType) {
-  auto memrefType = expectedType.cast<ShapedType>();
+                                          Operation *op,
+                                          RankedTensorType expectedType) {
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op.emitError("expected rank to be smaller or equal to ")
+    return op->emitError("expected rank to be smaller or equal to ")
            << "the other rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op.emitError("expected type to be ")
+    return op->emitError("expected type to be ")
            << expectedType << " or a rank-reduced version. (size mismatch) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op.emitError("expected element type to be ")
-           << memrefType.getElementType();
+    return op->emitError("expected element type to be ")
+           << expectedType.getElementType();
   default:
     llvm_unreachable("unexpected extract_slice op verification result");
   }
@@ -2147,9 +2145,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
 /// Rank-reducing type verification for both InsertSliceOp and
 /// ParallelInsertSliceOp.
 static SliceVerificationResult verifyInsertSliceOp(
-    ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
-    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
-    ShapedType *expectedType = nullptr) {
+    RankedTensorType srcType, RankedTensorType dstType,
+    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type
   // inference.
   RankedTensorType expected = ExtractSliceOp::inferResultType(
@@ -2161,7 +2159,7 @@ static SliceVerificationResult verifyInsertSliceOp(
 
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
-  ShapedType expectedType;
+  RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
@@ -2334,8 +2332,10 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
     auto src =
         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-    auto srcType = src.getType().template cast<ShapedType>();
-    auto dstType = dst.getType().template cast<ShapedType>();
+    auto srcType = src.getType().template dyn_cast<RankedTensorType>();
+    auto dstType = dst.getType().template dyn_cast<RankedTensorType>();
+    if (!srcType || !dstType)
+      return failure();
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             insertSliceOp.getStaticSizes(),
                             insertSliceOp.getStaticStrides()) !=
@@ -3072,7 +3072,7 @@ LogicalResult ParallelInsertSliceOp::verify() {
     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
            << *(getOperation()->getParentOp());
 
-  ShapedType expectedType;
+  RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
@@ -3307,9 +3307,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                ? packOrUnPack.getSourceType()
-                                : packOrUnPack.getDestType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -3344,7 +3344,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  ShapedType expectedPackedType = PackOp::inferPackedType(
+  RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
   if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
     return op->emitError("the shape of output is not large enough to hold the "
@@ -3594,10 +3594,10 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
-                                   ArrayRef<int64_t> innerTileSizes,
-                                   ArrayRef<int64_t> innerDimsPos,
-                                   ArrayRef<int64_t> outerDimsPerm) {
+RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+                                         ArrayRef<int64_t> innerTileSizes,
+                                         ArrayRef<int64_t> innerDimsPos,
+                                         ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
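
A worked example of the retyped inferPackedType (hypothetical values,
not from the patch): packing a tensor<128x256xf32> with 32x16 inner
tiles on dimensions 0 and 1 gives outer dims 128/32 = 4 and 256/16 = 16,
so the inferred packed type is tensor<4x16x32x16xf32>. Assuming a
builder `b`:

    auto src = RankedTensorType::get({128, 256}, b.getF32Type());
    RankedTensorType packed = tensor::PackOp::inferPackedType(
        src, /*innerTileSizes=*/{32, 16}, /*innerDimsPos=*/{0, 1});
    // packed is tensor<4x16x32x16xf32>.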


        

