[Mlir-commits] [mlir] 96e9b6c - Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."

Hanhan Wang llvmlistbot at llvm.org
Tue Apr 5 15:06:05 PDT 2022


Author: Hanhan Wang
Date: 2022-04-05T15:05:41-07:00
New Revision: 96e9b6c9dc60946f08399def879a19395bc98107

URL: https://github.com/llvm/llvm-project/commit/96e9b6c9dc60946f08399def879a19395bc98107
DIFF: https://github.com/llvm/llvm-project/commit/96e9b6c9dc60946f08399def879a19395bc98107.diff

LOG: Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."

This reverts commit 64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7.

An invalid tensor.expand_shape op is generated with this commit. To reproduce:

$ mlir-opt -canonicalize a.mlir

where a.mlir contains:
```
func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
  %6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
  %10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  %11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %12 = linalg.init_tensor [8] : tensor<8xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  %16 = linalg.init_tensor [] : tensor<f32>
  %17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<f32>
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
  %19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %20 = linalg.init_tensor [8] : tensor<8xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
  %24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
  %25 = linalg.init_tensor [8] : tensor<8xf32>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
  %28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
  %29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
  return %29 : tensor<1x1xf32>
}
```
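
For context, the reverted change added a `ComposeCollapseOfExpandOp` pattern that composed `collapse_shape(expand_shape(%src))` pairs into a single reshape. A minimal sketch of the intended rewrite, reproduced from the doc comment that this revert removes from ReshapeOpsUtils.h:

```
%0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
%1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
    : tensor<?x?x?x1xi64> into tensor<?x?xi64>

// Composes into a single collapse, because sets [0] and [1] of the expand
// reassociation together cover the set [0, 1] of the collapse reassociation:
%0 = tensor.collapse_shape %arg [[0, 1], [2]]
    : tensor<?x?x?xi64> into tensor<?x?xi64>
```

The repro above exercises this composition on reshapes involving unit dimensions, where it ends up generating the invalid tensor.expand_shape described in the log.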

Differential Revision: https://reviews.llvm.org/D123161

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e2b4c0742ffdf..dfeac25fd6c99 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,12 +68,6 @@ SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
 Optional<SmallVector<ReassociationIndices>>
 getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
 
-/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
-/// possible.
-Optional<SmallVector<ReassociationIndices>>
-getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape);
-
 /// Return true if the reassociation specification is valid, false otherwise.
 /// When false, the `invalidIndex` integer pointer is optionally filled with the
 /// index of the offending reassociation map.
@@ -162,13 +156,10 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
       op.getReassociationIndices(), isExpandingReshape);
 }
 
-/// Returns true iff the type is a MemRefType and has a non-identity layout.
-bool hasNonIdentityLayout(Type type);
-
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
 template <typename ReshapeOpTy>
-struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
@@ -177,12 +168,6 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
       return failure();
 
     ShapedType resultType = reshapeOp.getResultType();
-
-    if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
-        hasNonIdentityLayout(reshapeOp.src().getType()) ||
-        hasNonIdentityLayout(reshapeOp.result().getType()))
-      return failure();
-
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
                                     reshapeOp.getReassociationIndices(),
@@ -195,180 +180,46 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   }
 };
 
-/// Pattern to compose
-/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
-/// In that case both `srcType` and `resultType` can be expressed as a function
-/// of `intermediateType`.
-/// In order to demonstrate the approach, let's assume that `rank(srcType) >
-/// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
-/// In that case, we can iterate over every set of indices in `reassociation_2`
-/// and try to find ids of sets of indices in `reassociation_1` that cover it
-/// completely.
-///
-/// Example:
-///
-///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
-///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
-///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
-///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
-///
-/// can be canonicalized into
-///
-///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
-///     : tensor<?x?x?xi64> into tensor<?x?xi64>
-///
-/// because [0] and [1] from `expand_shape` reassociation cover completely
-/// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
-/// indices, then we fail.
-//
-/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
-/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
-struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
-  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
-    if (!expandOp)
+    auto srcReshapeOp =
+        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+    if (!srcReshapeOp)
       return failure();
 
-    ShapedType srcType = expandOp.getSrcType();
-    ShapedType resultType = collapseOp.getResultType();
-
-    if (hasNonIdentityLayout(collapseOp.src().getType()) ||
-        hasNonIdentityLayout(expandOp.src().getType()) ||
-        hasNonIdentityLayout(expandOp.result().getType()))
-      return failure();
+    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
+    ShapedType resultType = reshapeOp.getResultType();
 
-    int64_t srcRank = srcType.getRank();
-    int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    // If the source reshape can be collapsed/expanded into the target reshape
+    // they can still be folded. This can only be reasoned about statically
+    // for cases where
+    // - either all shapes are static, or
+    // - The number of dynamic dimensions matches in the source of source and
+    //   result with all other dimensions being 1.
+    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+    if (!reassociationIndices)
       return failure();
-
-    SmallVector<ReassociationIndices, 4> higherRankReassociation,
-        lowerRankReassociation;
-
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
-      higherRankReassociation = expandOp.getReassociationIndices();
-      lowerRankReassociation = collapseOp.getReassociationIndices();
-    } else {
-      higherRankReassociation = collapseOp.getReassociationIndices();
-      lowerRankReassociation = expandOp.getReassociationIndices();
-    }
-
-    size_t higherRankIndicesID = 0;
-    SmallVector<ReassociationIndices, 4> composedReassociation;
-    for (const auto &lowerRankIndices : lowerRankReassociation) {
-      ReassociationIndices composedIndices;
-      while (higherRankIndicesID < higherRankReassociation.size()) {
-        auto rightmostIndex =
-            higherRankReassociation[higherRankIndicesID].back();
-        if (rightmostIndex > lowerRankIndices.back())
-          return failure();
-        composedIndices.push_back(higherRankIndicesID++);
-        if (rightmostIndex == lowerRankIndices.back())
-          break;
-      }
-      composedReassociation.push_back(composedIndices);
-    }
-    if (isResultCollapsed)
-      rewriter.replaceOpWithNewOp<CollapseOpTy>(
-          collapseOp, resultType, expandOp.src(), composedReassociation);
+    bool originalOpExpands =
+        intermediateType.getRank() > srcReshapeSrcType.getRank();
+    bool resultingOpExpands =
+        resultType.getRank() > srcReshapeSrcType.getRank();
+    if (!(resultingOpExpands ^ originalOpExpands))
+      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
     else
-      rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.src(), composedReassociation);
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
     return success();
   }
 };
 
-template <typename ExpandOpTy, typename CollapseOpTy>
-struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
-  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
-    if (!collapseOp)
-      return failure();
-
-    ShapedType srcType = collapseOp.getSrcType();
-    ShapedType resultType = expandOp.getResultType();
-
-    if (hasNonIdentityLayout(expandOp.src().getType()) ||
-        hasNonIdentityLayout(collapseOp.src().getType()) ||
-        hasNonIdentityLayout(collapseOp.result().getType()))
-      return failure();
-
-    int64_t srcRank = srcType.getRank();
-    int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
-      return failure();
-
-    auto srcReassociation = collapseOp.getReassociationIndices();
-    auto resultReassociation = expandOp.getReassociationIndices();
-    if (srcRank > resultRank) {
-      auto composedReassociation = findCollapsingReassociation(
-          srcReassociation, resultReassociation, srcType.getShape(),
-          resultType.getShape());
-      if (!composedReassociation.hasValue())
-        return failure();
-
-      rewriter.replaceOpWithNewOp<CollapseOpTy>(
-          expandOp, resultType, collapseOp.src(), *composedReassociation);
-      return success();
-    }
-    auto composedReassociation =
-        findCollapsingReassociation(resultReassociation, srcReassociation,
-                                    resultType.getShape(), srcType.getShape());
-    if (!composedReassociation.hasValue())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.src(), *composedReassociation);
-    return success();
-  }
-
-private:
-  // Attempts to find a way to collapse `srcShape` to `resultShape` by
-  // collapsing subshapes defined by the reassociation indices.
-  Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
-      ArrayRef<ReassociationIndices> srcReassociation,
-      ArrayRef<ReassociationIndices> resultReassociation,
-      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
-    SmallVector<ReassociationIndices, 4> composedReassociation;
-
-    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
-      auto &srcIndices = std::get<0>(item);
-      auto &resultIndices = std::get<1>(item);
-      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
-      auto resultSubShape =
-          resultShape.slice(resultIndices.front(), resultIndices.size());
-
-      if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
-          composedReassociation.push_back(srcIndices);
-        else
-          return llvm::None;
-      }
-
-      // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
-      auto subShapeReassociation =
-          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
-      if (!subShapeReassociation.hasValue())
-        return llvm::None;
-
-      // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
-      }
-    }
-    return {std::move(composedReassociation)};
-  }
-};
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bb36f3d00d179..5a8bd2b8dd551 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1793,9 +1793,8 @@ LogicalResult ExpandShapeOp::verify() {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
-      context);
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2000,8 +1999,8 @@ struct CollapseShapeOpMemRefCastFolder
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5b52a3fdd24ce..1c8065ec88095 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
               FoldReshapeWithConstant<CollapseShapeOp>,
               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 64937be9fac05..03cd3af2e7bec 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,23 +18,18 @@ using namespace mlir;
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
-  if (sourceType.getRank() > targetType.getRank())
-    return getReassociationIndicesForCollapse(sourceType.getShape(),
-                                              targetType.getShape());
+  // Make sourceType the greater-rank type of the two. If they have the same
+  // rank, then it is an unsupported reshape op.
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
   if (sourceType.getRank() < targetType.getRank())
-    return getReassociationIndicesForCollapse(targetType.getShape(),
-                                              sourceType.getShape());
-  return llvm::None;
-}
+    std::swap(sourceType, targetType);
 
-Optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return llvm::None;
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+  reassociationMap.reserve(targetType.getRank());
 
   ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
@@ -42,7 +37,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
     unsigned targetDim = reassociationMap.size();
     // If we have mapped all the target dimensions stop and handle the remaining
     // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
+    if (targetDim == targetType.getRank())
       break;
 
     int64_t currTargetShape = targetShape[targetDim];
@@ -192,7 +187,6 @@ mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
   }
   return maps;
 }
-
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                 int *invalidIndex) {
   if (reassociation.empty())
@@ -264,9 +258,3 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
   }
   return success();
 }
-
-bool mlir::hasNonIdentityLayout(Type type) {
-  if (auto memrefType = type.dyn_cast<MemRefType>())
-    return !memrefType.getLayout().isIdentity();
-  return false;
-}

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 1a01460a24dc9..8a4f80e77b61f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index)  {
 
 // -----
 
-func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
-    -> memref<f32> {
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+                                             -> memref<f32> {
   %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
       : memref<1x1x1xf32> into memref<1xf32>
   %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
 //       CHECK:   memref.collapse_shape %{{.*}} []
 //  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
 
 // -----
 
-func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
     -> memref<?x?xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,30 +323,13 @@ func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
       : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_memref_reshapes
 //       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   memref.collapse_shape
 
 // -----
 
-func @do_not_compose_collapse_of_expand_non_identity_layout(
-    %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
-    -> memref<?xf32> {
-  %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
-    memref<?x?xf32, offset : 0, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
-  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
-    memref<?xf32>
-  return %2 : memref<?xf32>
-}
-// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
-// CHECK: expand
-// CHECK: collapse
-
-// -----
-
-func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
     -> memref<?x6x4x5x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
@@ -354,46 +337,45 @@ func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
       : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
   return %1 : memref<?x6x4x5x?xf32>
 }
-// CHECK-LABEL: func @compose_expand_of_expand
+// CHECK-LABEL: expanding_memref_reshapes
 //       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   memref.expand_shape
 
 // -----
 
-func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
-    -> memref<1x1x1xf32> {
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+                                             -> memref<1x1x1xf32> {
   %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
   %1 = memref.expand_shape %0 [[0, 1, 2]]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
-// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
 //       CHECK:   memref.expand_shape %{{.*}} []
 //  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
 
 // -----
 
-func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
-// CHECK-LABEL: func @fold_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
 
-func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
-    -> memref<?x?xf32> {
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @fold_collapse_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9996b9776c4d5..22770c2e67342 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
 
 // -----
 
-func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
     -> tensor<?x6x4x?x5xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,51 +654,49 @@ func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
       : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: compose_expand_of_expand
+// CHECK-LABEL: expanding_tensor_reshapes
 //       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   tensor.expand_shape
 
 // -----
 
-func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
   %1 = tensor.expand_shape %0 [[0, 1, 2]]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
-// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
 //       CHECK:   tensor.expand_shape %{{.*}} []
 //  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
 
 // -----
 
-func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
-// CHECK-LABEL: @fold_collapse_of_expand
+// CHECK-LABEL: @fold_tensor_reshape
 //   CHECK-NOT:   linalg.{{.*}}shape
 
 // -----
 
-func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
+func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @fold_collapse_of_expand_dynamic
+// CHECK-LABEL: @fold_tensor_reshape_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
-
-func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -706,7 +704,7 @@ func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
-//      CHECK: func @compose_expand_of_collapse
+//      CHECK: func @reshape_collapse
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
 //      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
@@ -714,7 +712,7 @@ func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
 
 // -----
 
-func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
+func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -722,7 +720,7 @@ func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
-//      CHECK: func @compose_expand_of_collapse_7D
+//      CHECK: func @reshape_expand
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
 //      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
@@ -730,37 +728,20 @@ func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
 
 // -----
 
-func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
-    -> tensor<?x?xi64> {
-  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
-    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
-  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
-    : tensor<?x?x?x1xi64> into tensor<?x?xi64>
-  return %1 : tensor<?x?xi64>
-}
-// CHECK-LABEL: func @compose_collapse_of_expand
-//       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>)
-//  CHECK-NEXT: tensor.collapse_shape %[[ARG]]
-//  CHECK-SAME:   [0, 1], [2]
-//  CHECK-SAME:   : tensor<?x?x?xi64> into tensor<?x?xi64>
-
-// -----
-
-func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
-    -> tensor<4x512xf32> {
+func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
     : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
-//       CHECK: func @compose_collapse_of_expand_1D
+//       CHECK: func @expand_reshape_1D
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
 
-// CHECK-LABEL: func @zero_rank_reshape_multi
+// CHECK-LABEL: zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -771,7 +752,7 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
+func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -779,39 +760,39 @@ func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_tensor_reshapes
 //       CHECK:   tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   tensor.collapse_shape
 
 // -----
 
-func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
     -> tensor<f32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
       : tensor<1x1x1xf32> into tensor<1xf32>
   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
 //       CHECK:   tensor.collapse_shape %{{.*}} []
 //  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
 
 // -----
 
-func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
     : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
-//       CHECK: func @fold_collapse_of_expand_1D
+//       CHECK: func @fold_reshape_1D
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>
 
 // -----
 
-func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
     : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -819,13 +800,13 @@ func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
-//       CHECK: func @fold_collapse_of_expand_unit_dims
+//       CHECK: func @fold_reshape_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
 //  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
-func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
     : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -833,70 +814,69 @@ func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
-//       CHECK: func @compose_collapse_of_expand_unit_dims
+//       CHECK: func @expand_reshape_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
 //  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
-func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-    -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-//       CHECK: func @compose_collapse_of_expand_trailing_unit_dims
+//       CHECK: func @fold_reshape_trailing_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @compose_collapse_of_collapse_unit_dims_dynamic(
-    %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
+func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
+    -> tensor<?x?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
     : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
     : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
-//       CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
+//       CHECK: func @collapse_reshape_unit_dims_dynamic
 //       CHECK: tensor.collapse_shape
 //  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
 //  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
 
 // -----
 
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-    -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+//       CHECK: func @fold_reshape_trailing_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
-    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
+func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
+    -> tensor<?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
       : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
       : tensor<?x1x1x1xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
-//       CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
+//       CHECK: func @fold_reshape_trailing_unit_dims_dynamic
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
 //  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
 
 // -----
 
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
       : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -904,28 +884,27 @@ func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
 }
-//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+//       CHECK: func @fold_reshape_trailing_unit_dims
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
 //  CHECK-SAME:   tensor<12x42x1x1xf32> into tensor<12x42xf32>
 
 // -----
 
-func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
-    -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
 //       CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
 //  CHECK-SAME:   tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
 
-func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -933,21 +912,20 @@ func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
-// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
+// CHECK-LABEL: func @no_fold_reshape_incompatible
 //       CHECK:   tensor.expand_shape
 //       CHECK:   tensor.collapse_shape
 
 // -----
 
-func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
-    -> tensor<12x1xf32> {
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
-//      CHECK: func @no_fold_collapse_of_expand_empty_expr
+//      CHECK: func @no_fold_reshape_empty_expr
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
 //      CHECK:    %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:      [0], [1], [2, 3]
@@ -1024,11 +1002,11 @@ func @fold_rank() -> (index) {
 
 // -----
 
-// CHECK-LABEL: func @pad_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //   CHECK-NOT:   tensor.pad
 //       CHECK:   return %[[ARG0]]
-func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1040,11 +1018,11 @@ func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
-// CHECK-LABEL: func @pad_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //       CHECK:   %[[PAD:.*]] = tensor.pad
 //       CHECK:   return %[[PAD]]
-func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1056,7 +1034,7 @@ func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
-// CHECK-LABEL:   func @pad_after_cast_different_shape(
+// CHECK-LABEL:   func @pad_tensor_after_cast_different_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1068,7 +1046,7 @@ func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK:         }
-func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
     -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1081,7 +1059,7 @@ func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
 
 // -----
 
-// CHECK-LABEL:   func @pad_after_cast_same_shape(
+// CHECK-LABEL:   func @pad_tensor_after_cast_same_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
 // CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1092,7 +1070,7 @@ func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:           return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
 // CHECK:         }
-func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
     -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1105,11 +1083,11 @@ func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
 
 // -----
 
-// CHECK-LABEL: func @pad_of_cast(
+// CHECK-LABEL: func @pad_tensor_of_cast(
 // CHECK-NOT:     tensor.cast
 // CHECK:         tensor.pad
 // CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1155,7 +1133,7 @@ func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> ten
 
 // -----
 
-func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1165,17 +1143,17 @@ func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   } : tensor<?x?xf32> to tensor<4x4xf32>
   return %1 : tensor<4x4xf32>
 }
-// CHECK-LABEL: @pad_cast
+// CHECK-LABEL: @tensor_pad_cast
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK: return %[[ARG0]]
 
 // -----
 
-// CHECK-LABEL: func @fold_pad_source_cast(
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[0, 1]  {
