[Mlir-commits] [mlir] a08b750 - [mlir][tensor] InsertSliceOp verification.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 30 12:38:34 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-30T20:37:06Z
New Revision: a08b750ce9df2bf1cf9270d83c50de68eeb8b6f5

URL: https://github.com/llvm/llvm-project/commit/a08b750ce9df2bf1cf9270d83c50de68eeb8b6f5
DIFF: https://github.com/llvm/llvm-project/commit/a08b750ce9df2bf1cf9270d83c50de68eeb8b6f5.diff

LOG: [mlir][tensor] InsertSliceOp verification.

This revision reintroduces tensor.insert_slice verification, which seems
to have vanished over time: a verifier was initially introduced in
cf9503c1b752062d9abfb2c7922a50574d9c5de4, but for some reason invalid.mlir was
never properly updated; as time passed the verifier was no longer called and
the code was later deleted.

As a consequence, a non-negligible number of tests had gone astray, relying on
invalid tensor.insert_slice semantics, and needed to be fixed.
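
As a small illustration (a sketch taken from the tests updated in this patch),
the reintroduced verifier now rejects IR such as the following, where the
static sizes cannot be matched against the type of the inserted tensor:

```
// Invalid: the static sizes [12345, 67890] do not match the inserted
// tensor<4x4xf32>, so the new verifier reports a size mismatch.
%rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
  tensor<4x4xf32> into tensor<?x?xf32>
```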

Also, extract isRankReducedType from TensorOps for better reuse.
Originally, this facility was used by both the tensor and memref forms, but
it got copied around as dialects were split.
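
As a rough sketch of what the shared helper checks (SSA names and types below
are purely illustrative), a candidate type is considered a valid rank-reduced
version of the inferred type when it can be obtained by dropping dimensions of
static size 1; this applies to both the tensor and memref forms:

```
// tensor<16x4xf32> is a rank-reduced version of the inferred tensor<1x16x4xf32>.
%0 = tensor.extract_slice %t[0, 0, 0][1, 16, 4][1, 1, 1]
  : tensor<8x16x4xf32> to tensor<16x4xf32>
// memref.subview performs the same shape check, plus the layout and memory
// space checks that remain in MemRefOps.cpp.
%1 = memref.subview %m[0, 0, 0][1, 16, 4][1, 1, 1]
  : memref<8x16x4xf32> to memref<16x4xf32>
```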

Differential Revision: https://reviews.llvm.org/D114715

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2151606f4bfb3..5442fed96dd1a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -184,11 +184,14 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.
 
-    After buffer-allocation, the "extract_slice" op is expected to lower into a
-    "subview" op.
+    After buffer allocation, the "extract_slice" op is expected to lower into a
+    memref.subview op.
 
     An extract_slice operation may additionally reduce the rank of the resulting
     tensor by removing dimensions that are statically known to be of size 1.
+    This rank-reduction behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.
 
     Example:
 
@@ -196,8 +199,8 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     // Rank-reducing extract_slice.
     %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<8x16x4xf32> to tensor<16x4xf32>
-    %3 = tensor.extract_slice %2[3, 4, 2][1, 6, 3][1, 1, 1] :
-      tensor<8x16x4xf32> to tensor<6x3xf32>
+    %3 = tensor.extract_slice %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<8x16x4xf32> to tensor<1x?xf32>
     ```
   }];
 
@@ -257,24 +260,28 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
     /// An extract_slice result type can be fully inferred from the source type
     /// and the static representation of offsets, sizes and strides. Special
     /// sentinels encode the dynamic case.
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferResultType(
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<int64_t> staticOffsets,
+      ArrayRef<int64_t> staticSizes,
+      ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferResultType(
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<OpFoldResult> staticOffsets,
+      ArrayRef<OpFoldResult> staticSizes,
+      ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+      unsigned resultRank,
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<int64_t> staticOffsets,
+      ArrayRef<int64_t> staticSizes,
+      ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+      unsigned resultRank,
+      RankedTensorType sourceRankedTensorType,
+      ArrayRef<OpFoldResult> staticOffsets,
+      ArrayRef<OpFoldResult> staticSizes,
+      ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -469,8 +476,27 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.
 
-    After buffer-allocation, the "insert_slice" op is expected to become an
-    in-place buffer update.
+    After buffer allocation, the "insert_slice" op is expected to lower into a
+    memref.subview op.
+
+    An insert_slice operation may additionally specify insertion into a tensor
+    of higher rank than the source tensor, along dimensions that are statically
+    known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.
+    The rank-altering behavior of tensor.insert_slice matches the rank-reducing
+    behavior of tensor.extract_slice.
+
+    Example:
+
+    ```
+    // Rank-altering insert_slice.
+    %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      tensor<16x4xf32> into tensor<8x16x4xf32>
+    %3 = tensor.insert_slice %tt into %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<1x?xf32> into tensor<8x16x4xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -493,8 +519,6 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     attr-dict `:` type($source) `into` type($dest)
   }];
 
-  let verifier = ?;
-
   let builders = [
     // Build a InsertSliceOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "Value":$dest,

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 5838f1d1fb241..bf5d0a8bd1bf8 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -21,9 +21,10 @@
 
 namespace mlir {
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+///   a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
@@ -31,11 +32,8 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel);
 
-/// Helper function to dispatch multiple OpFoldResults into either the
-/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
+/// Helper function to dispatch multiple OpFoldResults according to the behavior
+/// of `dispatchIndexOpFoldResult` for a single OpFoldResult.
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec,

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f3d2c24073dc6..10d8a5847ebd2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -369,6 +369,25 @@ llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);
 
+/// Enum that captures information related to verifier error conditions on
+/// slice insert/extract type of ops.
+enum class SliceVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  ElemTypeMismatch,
+  // Error codes for ops with a memory space and a layout annotation.
+  MemSpaceMismatch,
+  LayoutMismatch
+};
+
+/// Check if `originalType` can be rank reduced to `candidateReducedType` type
+/// by dropping some dimensions with static size `1`.
+/// Return `SliceVerificationResult::Success` on success or an appropriate error
+/// code.
+SliceVerificationResult isRankReducedType(ShapedType originalType,
+                                          ShapedType candidateReducedType);
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 77cf563abe1a1..7961638a4661b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2248,8 +2248,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
     Location loc = op.getLoc();
     int axis = op.axis();
-    Value axisValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
+    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(axis));
     int rank = resultType.getRank();
     SmallVector<Value, 3> offsets, sizes, strides;
     sizes.reserve(rank);
@@ -2257,31 +2257,41 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
     for (int i = 0; i < rank; ++i) {
-      sizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
+      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          loc, adaptor.getOperands()[0], i));
     }
 
     Value resultDimSize = sizes[axis];
     for (auto arg : adaptor.getOperands().drop_front()) {
-      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
+      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      resultDimSize =
+          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
     }
     sizes[axis] = resultDimSize;
 
     Value init = rewriter.create<linalg::InitTensorOp>(
         loc, resultType.getShape(), resultType.getElementType());
 
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
+    Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
     Value result =
         rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
 
+    auto toOpFoldResult = [](Value v) -> OpFoldResult {
+      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+      if (!op)
+        return v;
+      return op.getValue();
+    };
     for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
-                                                      sizes, strides);
+      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, arg, result,
+          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
       offsets[axis] =
-          rewriter.create<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
     }
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9ef930cb204c1..36828eabd59f7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -835,16 +835,14 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
+
 void InitTensorOp::build(OpBuilder &b, OperationState &result,
                          ArrayRef<OpFoldResult> sizes, Type elementType,
                          ArrayRef<NamedAttribute> attrs) {
-  unsigned rank = sizes.size();
   SmallVector<Value, 4> dynamicSizes;
   SmallVector<int64_t, 4> staticSizes;
-  for (unsigned i = 0; i < rank; ++i) {
-    dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
-                              ShapedType::kDynamicSize);
-  }
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
   auto resultType = RankedTensorType ::get(staticSizes, elementType);
   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
   result.addAttributes(attrs);
@@ -1127,19 +1125,16 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                         ArrayRef<NamedAttribute> attrs) {
   assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
-  unsigned rank = sourceType.getRank();
   SmallVector<Value, 4> dynamicLow, dynamicHigh;
   SmallVector<int64_t, 4> staticLow, staticHigh;
-  for (unsigned i = 0; i < rank; ++i) {
-    // staticLow and staticHigh have full information of the padding config.
-    // This will grow staticLow and staticHigh with 1 value. If the config is
-    // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
-    // value as well.
-    dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow,
-                              ShapedType::kDynamicSize);
-    dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh,
-                              ShapedType::kDynamicSize);
-  }
+  // staticLow and staticHigh have full information of the padding config.
+  // This will grow staticLow and staticHigh with 1 value. If the config is
+  // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
+  // value as well.
+  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
+                             ShapedType::kDynamicSize);
   if (!resultType) {
     resultType =
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b9bd01b439a9d..938197df59c22 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -504,11 +504,13 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
   return numOccurences;
 }
 
-/// Given the type of the un-rank reduced subview result type and the
-/// rank-reduced result type, computes the dropped dimensions. This accounts for
-/// cases where there are multiple unit-dims, but only a subset of those are
-/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
-/// dimension is dropped the stride must be dropped too.
+/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
+/// to be a subset of `originalType` with some `1` entries erased, return the
+/// set of indices that specifies which of the entries of `originalShape` are
+/// dropped to obtain `reducedShape`.
+/// This accounts for cases where there are multiple unit-dims, but only a
+/// subset of those are dropped. For MemRefTypes these can be disambiguated
+/// using the strides. If a dimension is dropped the stride must be dropped too.
 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
@@ -1548,8 +1550,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                              staticStrides, ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
-                                    staticSizes, staticStrides)
-      .cast<MemRefType>();
+                                    staticSizes, staticStrides);
 }
 
 Type SubViewOp::inferRankReducedResultType(
@@ -1706,88 +1707,58 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-enum SubViewVerificationResult {
-  Success,
-  RankTooLarge,
-  SizeMismatch,
-  ElemTypeMismatch,
-  MemSpaceMismatch,
-  AffineMapMismatch
-};
-
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static SubViewVerificationResult
-isRankReducedType(Type originalType, Type candidateReducedType,
-                  ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
-  if (originalType == candidateReducedType)
-    return SubViewVerificationResult::Success;
-  if (!originalType.isa<MemRefType>())
-    return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
-    return SubViewVerificationResult::Success;
-
-  ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType candidateReducedShapedType =
-      candidateReducedType.cast<ShapedType>();
-
-  // Rank and size logic is valid for all ShapedTypes.
-  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> candidateReducedShape =
-      candidateReducedShapedType.getShape();
-  unsigned originalRank = originalShape.size(),
-           candidateReducedRank = candidateReducedShape.size();
-  if (candidateReducedRank > originalRank)
-    return SubViewVerificationResult::RankTooLarge;
+static SliceVerificationResult
+isRankReducedMemRefType(MemRefType originalType,
+                        MemRefType candidateReducedType,
+                        ArrayRef<OpFoldResult> sizes) {
+  auto partialRes =
+      isRankReducedType(originalType, candidateReducedType);
+  if (partialRes != SliceVerificationResult::Success)
+    return partialRes;
 
   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+  MemRefType candidateReduced =
+      candidateReducedType.cast<MemRefType>();
 
   auto optionalUnusedDimsMask =
       computeMemRefRankReductionMask(original, candidateReduced, sizes);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
-    return SubViewVerificationResult::SizeMismatch;
+    return SliceVerificationResult::LayoutMismatch;
 
-  if (originalShapedType.getElementType() !=
-      candidateReducedShapedType.getElementType())
-    return SubViewVerificationResult::ElemTypeMismatch;
-
-  // Strided layout logic is relevant for MemRefType only.
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
-    return SubViewVerificationResult::MemSpaceMismatch;
-  return SubViewVerificationResult::Success;
+    return SliceVerificationResult::MemSpaceMismatch;
+
+  return SliceVerificationResult::Success;
 }
 
 template <typename OpTy>
-static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
-                                            OpTy op, Type expectedType,
-                                            StringRef errMsg = "") {
+static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
+                                            OpTy op, Type expectedType) {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
-  case SubViewVerificationResult::Success:
+  case SliceVerificationResult::Success:
     return success();
-  case SubViewVerificationResult::RankTooLarge:
+  case SliceVerificationResult::RankTooLarge:
     return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank. " << errMsg;
-  case SubViewVerificationResult::SizeMismatch:
+           << "the source rank. ";
+  case SliceVerificationResult::SizeMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes) "
-           << errMsg;
-  case SubViewVerificationResult::ElemTypeMismatch:
+           << " or a rank-reduced version. (mismatch of result sizes) ";
+  case SliceVerificationResult::ElemTypeMismatch:
     return op.emitError("expected result element type to be ")
-           << memrefType.getElementType() << errMsg;
-  case SubViewVerificationResult::MemSpaceMismatch:
-    return op.emitError("expected result and source memory spaces to match.")
-           << errMsg;
-  case SubViewVerificationResult::AffineMapMismatch:
+           << memrefType.getElementType();
+  case SliceVerificationResult::MemSpaceMismatch:
+    return op.emitError("expected result and source memory spaces to match.");
+  case SliceVerificationResult::LayoutMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result affine map) "
-           << errMsg;
+           << " or a rank-reduced version. (mismatch of result layout) ";
   }
   llvm_unreachable("unexpected subview verification result");
 }
@@ -1813,10 +1784,9 @@ static LogicalResult verify(SubViewOp op) {
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
 
-  std::string errMsg;
-  auto result =
-      isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
-  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
+  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
+                                        subViewType, op.getMixedSizes());
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0eff26bc86de4..7f1bd74cd37aa 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -655,10 +656,11 @@ static LogicalResult verify(ReshapeOp op) {
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
-Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
-                                     ArrayRef<int64_t> leadingStaticOffsets,
-                                     ArrayRef<int64_t> leadingStaticSizes,
-                                     ArrayRef<int64_t> leadingStaticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<int64_t> leadingStaticOffsets,
+                                ArrayRef<int64_t> leadingStaticSizes,
+                                ArrayRef<int64_t> leadingStaticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type and
   // strides=1.
@@ -673,11 +675,11 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                sourceRankedTensorType.getElementType());
 }
 
-Type ExtractSliceOp::inferResultType(
-    RankedTensorType sourceRankedTensorType,
-    ArrayRef<OpFoldResult> leadingStaticOffsets,
-    ArrayRef<OpFoldResult> leadingStaticSizes,
-    ArrayRef<OpFoldResult> leadingStaticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                ArrayRef<OpFoldResult> leadingStaticSizes,
+                                ArrayRef<OpFoldResult> leadingStaticStrides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
@@ -693,7 +695,7 @@ Type ExtractSliceOp::inferResultType(
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
-Type ExtractSliceOp::inferRankReducedResultType(
+RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<int64_t> leadingStaticOffsets,
     ArrayRef<int64_t> leadingStaticSizes,
@@ -717,7 +719,7 @@ Type ExtractSliceOp::inferRankReducedResultType(
   return inferredType;
 }
 
-Type ExtractSliceOp::inferRankReducedResultType(
+RankedTensorType ExtractSliceOp::inferRankReducedResultType(
     unsigned resultRank, RankedTensorType sourceRankedTensorType,
     ArrayRef<OpFoldResult> leadingStaticOffsets,
     ArrayRef<OpFoldResult> leadingStaticSizes,
@@ -746,10 +748,12 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+
                              ShapedType::kDynamicStrideOrOffset);
   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
   // Structuring implementation this way avoids duplication between builders.
@@ -797,89 +801,35 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
 }
 
-enum SliceVerificationResult {
-  Success,
-  RankTooLarge,
-  SizeMismatch,
-  ElemTypeMismatch,
-};
-
-/// Checks if `original` Type type can be rank reduced to `reduced` type.
-/// This function is slight variant of `is subsequence` algorithm where
-/// not matching dimension must be 1.
-static SliceVerificationResult
-isRankReducedType(Type originalType, Type candidateReducedType,
-                  std::string *errMsg = nullptr) {
-  if (originalType == candidateReducedType)
-    return SliceVerificationResult::Success;
-  if (!originalType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-  if (originalType.isa<RankedTensorType>() &&
-      !candidateReducedType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-
-  ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType candidateReducedShapedType =
-      candidateReducedType.cast<ShapedType>();
-
-  // Rank and size logic is valid for all ShapedTypes.
-  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> candidateReducedShape =
-      candidateReducedShapedType.getShape();
-  unsigned originalRank = originalShape.size(),
-           candidateReducedRank = candidateReducedShape.size();
-  if (candidateReducedRank > originalRank)
-    return SliceVerificationResult::RankTooLarge;
-
-  auto optionalUnusedDimsMask =
-      computeRankReductionMask(originalShape, candidateReducedShape);
-
-  // Sizes cannot be matched in case empty vector is returned.
-  if (!optionalUnusedDimsMask.hasValue())
-    return SliceVerificationResult::SizeMismatch;
-
-  if (originalShapedType.getElementType() !=
-      candidateReducedShapedType.getElementType())
-    return SliceVerificationResult::ElemTypeMismatch;
-
-  // We are done for the tensor case.
-  if (originalType.isa<RankedTensorType>())
-    return SliceVerificationResult::Success;
-
-  return SliceVerificationResult::Success;
-}
-
 template <typename OpTy>
 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
-                                          OpTy op, Type expectedType,
-                                          StringRef errMsg = "") {
+                                          OpTy op, Type expectedType) {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank. " << errMsg;
+    return op.emitError("expected rank to be smaller or equal to ")
+           << "the other rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op.emitError("expected result type to be ")
-           << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes) "
-           << errMsg;
+    return op.emitError("expected type to be ")
+           << expectedType << " or a rank-reduced version. (size mismatch) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op.emitError("expected result element type to be ")
-           << memrefType.getElementType() << errMsg;
+    return op.emitError("expected element type to be ")
+           << memrefType.getElementType();
+  default:
+    llvm_unreachable("unexpected extract_slice op verification result");
   }
-  llvm_unreachable("unexpected extract_slice op verification result");
 }
 
 /// Verifier for ExtractSliceOp.
 static LogicalResult verify(ExtractSliceOp op) {
   // Verify result type against inferred type.
-  auto expectedType = ExtractSliceOp::inferResultType(
-      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
-      extractFromI64ArrayAttr(op.static_sizes()),
-      extractFromI64ArrayAttr(op.static_strides()));
-  auto result = isRankReducedType(expectedType, op.getType());
+  auto expectedType =
+      ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
+                                      op.getMixedSizes(), op.getMixedStrides());
+  auto result =
+      isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
   return produceSliceErrorMsg(result, op, expectedType);
 }
 
@@ -1104,10 +1054,12 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+
                              ShapedType::kDynamicStrideOrOffset);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+
                              ShapedType::kDynamicStrideOrOffset);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
@@ -1128,6 +1080,19 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// Verifier for InsertSliceOp.
+static LogicalResult verify(InsertSliceOp op) {
+  // insert_slice is the inverse of extract_slice, use the same type inference.
+  auto expectedType = ExtractSliceOp::inferRankReducedResultType(
+      op.getSourceType().getRank(), op.getType(),
+      extractFromI64ArrayAttr(op.static_offsets()),
+      extractFromI64ArrayAttr(op.static_sizes()),
+      extractFromI64ArrayAttr(op.static_strides()));
+  auto result =
+      isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
+  return produceSliceErrorMsg(result, op, expectedType);
+}
+
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
 ///
@@ -1202,9 +1167,16 @@ class InsertSliceOpConstantArgumentFolder final
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
-        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+    auto sourceType = ExtractSliceOp::inferRankReducedResultType(
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    Value toInsert = insertSliceOp.source();
+    if (sourceType != insertSliceOp.getSourceType())
+      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                 sourceType, toInsert);
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
+        mixedStrides);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 24d8ac09deff9..3e50fac6fd3a8 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -13,22 +13,24 @@
 
 namespace mlir {
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+///   a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
+  auto v = ofr.dyn_cast<Value>();
+  if (!v) {
+    APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+    staticVec.push_back(apInt.getSExtValue());
     return;
   }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
+  dynamicVec.push_back(v);
+  staticVec.push_back(sentinel);
 }
 
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 64dceaaa4480d..33ed6b60932d4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -571,7 +571,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // Greedily insert `originalIdx` if no match.
+    // Greedily insert `originalIdx` if match.
     if (reducedIdx < reducedRank &&
         originalShape[originalIdx] == reducedShape[reducedIdx]) {
       reducedIdx++;
@@ -590,6 +590,39 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
   return unusedDims;
 }
 
+SliceVerificationResult
+mlir::isRankReducedType(ShapedType originalType,
+                        ShapedType candidateReducedType) {
+  if (originalType == candidateReducedType)
+    return SliceVerificationResult::Success;
+
+  ShapedType originalShapedType = originalType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
+
+  // Rank and size logic is valid for all ShapedTypes.
+  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
+  unsigned originalRank = originalShape.size(),
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
+    return SliceVerificationResult::RankTooLarge;
+
+  auto optionalUnusedDimsMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
+
+  // Sizes cannot be matched in case empty vector is returned.
+  if (!optionalUnusedDimsMask.hasValue())
+    return SliceVerificationResult::SizeMismatch;
+
+  if (originalShapedType.getElementType() !=
+      candidateReducedShapedType.getElementType())
+    return SliceVerificationResult::ElemTypeMismatch;
+
+  return SliceVerificationResult::Success;
+}
+
 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
   // Empty attribute is allowed as default memory space.
   if (!memorySpace)

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1cf88f9bc9709..15409702ec197 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -820,38 +820,24 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   // CHECK: [[STRIDE:%.+]]   = arith.constant 1
   // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
   // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
-  // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg1, [[AXIS]]
-  // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM0]], [[ARG1_AXIS]]
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
-  // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM0]]
-  // CHECK: [[ARG1_DIM0:%.+]] = tensor.dim %arg1, [[AXIS]]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
 
   // CHECK: [[AXIS:%.+]] = arith.constant 1
   // CHECK: [[STRIDE:%.+]]   = arith.constant 1
   // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
   // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
-  // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM1]], [[ARG1_AXIS]]
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
-  // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
-  // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM1]]
-  // CHECK: [[ARG1_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1]
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
   return
 }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 68fe343e412e2..4a8d2e48162d0 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -428,7 +428,9 @@ func @nested_extract_slice_and_insert(
     %A : tensor<?x?xf32>,
     %B : tensor<?x?xf32> {linalg.inplaceable = true},
     %C : tensor<?x?xf32> {linalg.inplaceable = true},
-    %idx : index)
+    %idx : index,
+    %sz1 : index,
+    %sz2 : index)
   ->  (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
 {
   %f0 = arith.constant 0.0 : f32
@@ -497,9 +499,9 @@ func @nested_extract_slice_and_insert(
   // CHECK-NEXT: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  %ssC = tensor.extract_slice %sC[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
-  %FC = linalg.fill(%f0, %ssC) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
-  %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
+  %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
+  %FC = linalg.fill(%f0, %ssC) : f32, tensor<?x4xf32> -> tensor<?x4xf32>
+  %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
   %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
   return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b0cddb26c6772..20cd8606c9391 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -592,7 +592,7 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
       linalg.yield %1 : f32
     } -> tensor<4xf32>
 
-    %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1]
+    %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1]
       : tensor<4xf32> into tensor<24xf32>
     linalg.yield %sum_sub : tensor<24xf32>
   }

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 861478d7cc171..73006149a37b5 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -644,7 +644,7 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 // -----
 
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error at +1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  // expected-error at +1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result layout)}}
   %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
   return
 }
@@ -653,7 +653,7 @@ func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg
 
 func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
     %arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
-  // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result layout)}}
   %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
   return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
 }

diff  --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index a07e5688ead36..79a692f9d59b1 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -250,7 +250,7 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
 //       CHECK:   scf.for
 //       CHECK:     tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
-                                         %t2 : tensor<?x?xf32>) -> index {
+                                         %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -258,9 +258,9 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
       -> (tensor<?x?xf32>, index) {
     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
     %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
-        : tensor<?x?xf32> into tensor<?x?xf32>
+        : tensor<10x10xf32> into tensor<?x?xf32>
     %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
-        : tensor<?x?xf32> into tensor<?x?xf32>
+        : tensor<10x10xf32> into tensor<?x?xf32>
     scf.yield %3, %dim : tensor<?x?xf32>, index
   }
   return %1 : index
@@ -274,7 +274,7 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
 //       CHECK:     scf.for
 //       CHECK:       tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
-                                        %t2 : tensor<?x?xf32>) -> index {
+                                        %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -284,7 +284,7 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
         -> (tensor<?x?xf32>, index) {
       %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
       %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
-          : tensor<?x?xf32> into tensor<?x?xf32>
+          : tensor<10x10xf32> into tensor<?x?xf32>
       scf.yield %4, %dim : tensor<?x?xf32>, index
     }
     scf.yield %2, %3 : tensor<?x?xf32>, index
@@ -292,6 +292,7 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
   return %1 : index
 }
 
+
 // -----
 
 // A test case that should not canonicalize because the loop is not shape

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9d9da02c0220f..1aa4008cf90ec 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -348,8 +348,10 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
 func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %c0 = arith.constant 0: index
   %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
-  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
+  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
 
@@ -408,9 +410,10 @@ func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : i
 }
 // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
-//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-//  CHECK-SAME:      : tensor<?x?xf32> into tensor<?x?x?xf32>
+//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
 //       CHEKC:   return %[[RESULT]]
 
 // -----
@@ -450,7 +453,7 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
   ^bb0(%arg4: index, %arg5: index):
     tensor.yield %1 : i32
   } : tensor<?x?xi32>
-  %3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
   return %3 : tensor<?x?xi32>
 }
 // CHECK-LABEL: func @insert_slice_propagate_dest_cast
@@ -462,9 +465,6 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
 // -----
 
 func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
   %c9 = arith.constant 9 : index
   %c3 = arith.constant 3 : index
   %2 = tensor.extract %arg1[] : tensor<i32>
@@ -472,7 +472,7 @@ func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : ten
   ^bb0(%arg2: index, %arg3: index):
     tensor.yield %2 : i32
   } : tensor<?x?xi32>
-  %5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
+  %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
   %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
   return %6 : tensor<3x9xi32>
 }
@@ -527,8 +527,9 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
 //      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
 //      CHECK:    return %[[r]]
 func @insert_tensor_cast_on_insert_slice_src(
-  %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+    %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
+  %c64 = arith.constant 64: index
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
   return %r : tensor<?x?x?xf32>
 }
@@ -559,13 +560,3 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
   // CHECK: return %[[INSERT]]
   return %1 : tensor<?x?x?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
-func @folding_incorrect_ir_triggers_infinite_loop(
-  %A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
-    tensor<4x4xf32> into tensor<?x?xf32>
-  return %rC: tensor<?x?xf32>
-}

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 51b67a76d14c3..f3c8ba28eb51e 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -149,8 +149,36 @@ func @tensor.reshape_num_elements_mismatch(
 
 // -----
 
-func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
+func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<?x?xf32>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<4xi8>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
   %0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
     : tensor<8x16x4xf32> to tensor<?x4x4xf32>
 
@@ -159,10 +187,38 @@ func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
 
 // -----
 
-func @slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>' or a rank-reduced version. (mismatch of result sizes)}}
-  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+func @insert_slice_wrong_result_rank(%t1: tensor<?xf32>, %t2: tensor<?x?xf32>, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor<?x?xf32> into tensor<?xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_result_rank(%t1: tensor<4xi8>, %t2: tensor<?xf32>, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.insert_slice %t1 into %t2[0][4][1] : tensor<4xi8> into tensor<?xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_static_type(%t1: tensor<4x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<4x4x4xf32> into tensor<8x16x4xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<?x4x4xf32> into tensor<8x16x4xf32>
 
   return
 }

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2353d0320b4e0..d8c5a415fcb8b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -78,3 +78,60 @@ func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
                : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
   return %new_unranked : tensor<*xf32>
 }
+
+// CHECK-LABEL: func @slice({{.*}}) {
+func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
+  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
+  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @insert_slice({{.*}}) {
+func @insert_slice(
+    %t: tensor<8x16x4xf32>,
+    %td: tensor<8x?x4xf32>,
+    %t2: tensor<16x32x8xf32>,
+    %t3: tensor<4x4xf32>,
+    %idx : index,
+    %sz : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][8, 16, 4][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][8, 16, 4][%c1, 1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
+  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x4xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x?x4xf32> into tensor<8x16x4xf32>
+  %4 = tensor.insert_slice %td into %t[0, %idx, 0][8, %sz, 4][1, 1, 1]
+    : tensor<8x?x4xf32> into tensor<8x16x4xf32>
+
+  return
+}

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 029d26c0c71e6..300c2542d8078 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -486,53 +486,3 @@ func @assume_alignment(%0: memref<4x4xf16>) {
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
-
-// CHECK-LABEL: func @slice({{.*}}) {
-func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
-  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> to tensor<?x?x?xf32>
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
-  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
-  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @insert_slice({{.*}}) {
-func @insert_slice(
-    %t: tensor<8x16x4xf32>,
-    %t2: tensor<16x32x8xf32>,
-    %t3: tensor<4x4xf32>,
-    %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
-  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<4x4xf32> into tensor<8x16x4xf32>
-
-  return
-}


        

