[Mlir-commits] [mlir] 747b10b - Revert "Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse).""

Alexander Belyaev llvmlistbot at llvm.org
Wed Apr 6 03:18:47 PDT 2022


Author: Alexander Belyaev
Date: 2022-04-06T12:18:30+02:00
New Revision: 747b10be95200a71f4e6b8ca9f0aaea20db1b164

URL: https://github.com/llvm/llvm-project/commit/747b10be95200a71f4e6b8ca9f0aaea20db1b164
DIFF: https://github.com/llvm/llvm-project/commit/747b10be95200a71f4e6b8ca9f0aaea20db1b164.diff

LOG: Revert "Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse).""

This reverts commit 96e9b6c9dc60946f08399def879a19395bc98107.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index dfeac25fd6c99..706d11d49be6a 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,6 +68,12 @@ SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
 Optional<SmallVector<ReassociationIndices>>
 getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
 
+/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
+/// possible.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape);
+
 /// Return true if the reassociation specification is valid, false otherwise.
 /// When false, the `invalidIndex` integer pointer is optionally filled with the
 /// index of the offending reassociation map.
@@ -156,10 +162,13 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
       op.getReassociationIndices(), isExpandingReshape);
 }
 
+/// Returns true iff the type is a MemRefType and has a non-identity layout.
+bool hasNonIdentityLayout(Type type);
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
 template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                 PatternRewriter &rewriter) const override {
@@ -168,6 +177,12 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
       return failure();
 
     ShapedType resultType = reshapeOp.getResultType();
+
+    if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
+        hasNonIdentityLayout(reshapeOp.src().getType()) ||
+        hasNonIdentityLayout(reshapeOp.result().getType()))
+      return failure();
+
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
                                     reshapeOp.getReassociationIndices(),
@@ -180,44 +195,181 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   }
 };
 
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
-  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+/// Pattern to compose
+/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
+/// In that case both `srcType` and `resultType` can be expressed as a function
+/// of `intermediateType`.
+/// In order to demonstrate the approach, let's assume that `rank(srcType) >
+/// rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
+/// In that case, we can iterate over every set of indices in `reassociation_2`
+/// and try to find the IDs of the sets of indices in `reassociation_1` that
+/// cover it completely.
+///
+/// Example:
+///
+///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+///
+/// can be canonicalized into
+///
+///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
+///     : tensor<?x?x?xi64> into tensor<?x?xi64>
+///
+/// because [0] and [1] from the `expand_shape` reassociation completely cover
+/// `[0, 1]` from the `collapse_shape` reassociation. If it is impossible to
+/// find such a union of indices, the pattern fails.
+///
+/// When `rank(srcType) < rank(resultType)`, we simply swap `reassociation_1`
+/// and `reassociation_2` and produce an `expand_shape`.
+template <typename CollapseOpTy, typename ExpandOpTy>
+struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
+  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcReshapeOp =
-        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
-    if (!srcReshapeOp)
+    auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
+    if (!expandOp)
       return failure();
 
-    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
-    ShapedType intermediateType = reshapeOp.getSrcType();
-    ShapedType resultType = reshapeOp.getResultType();
+    ShapedType srcType = expandOp.getSrcType();
+    ShapedType resultType = collapseOp.getResultType();
 
-    // If the source reshape can be collapsed/expanded into the target reshape
-    // they can still be folded. This can only be reasoned about statically
-    // for cases where
-    // - either all shapes are static, or
-    // - The number of dynamic dimensions matches in the source of source and
-    //   result with all other dimensions being 1.
-    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
-    if (!reassociationIndices)
+    if (hasNonIdentityLayout(collapseOp.src().getType()) ||
+        hasNonIdentityLayout(expandOp.src().getType()) ||
+        hasNonIdentityLayout(expandOp.result().getType()))
       return failure();
-    bool originalOpExpands =
-        intermediateType.getRank() > srcReshapeSrcType.getRank();
-    bool resultingOpExpands =
-        resultType.getRank() > srcReshapeSrcType.getRank();
-    if (!(resultingOpExpands ^ originalOpExpands))
-      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+
+    int64_t srcRank = srcType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if (srcType == resultType)
+      return failure();
+
+    SmallVector<ReassociationIndices, 4> higherRankReassociation,
+        lowerRankReassociation;
+
+    bool isResultCollapsed = srcRank > resultRank;
+    if (isResultCollapsed) {
+      higherRankReassociation = expandOp.getReassociationIndices();
+      lowerRankReassociation = collapseOp.getReassociationIndices();
+    } else {
+      higherRankReassociation = collapseOp.getReassociationIndices();
+      lowerRankReassociation = expandOp.getReassociationIndices();
+    }
+
+    size_t higherRankIndicesID = 0;
+    SmallVector<ReassociationIndices, 4> composedReassociation;
+    for (const auto &lowerRankIndices : lowerRankReassociation) {
+      ReassociationIndices composedIndices;
+      while (higherRankIndicesID < higherRankReassociation.size()) {
+        auto rightmostIndex =
+            higherRankReassociation[higherRankIndicesID].back();
+        if (rightmostIndex > lowerRankIndices.back())
+          return failure();
+        composedIndices.push_back(higherRankIndicesID++);
+        if (rightmostIndex == lowerRankIndices.back())
+          break;
+      }
+      composedReassociation.push_back(composedIndices);
+    }
+    if (isResultCollapsed)
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
     else
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+      rewriter.replaceOpWithNewOp<ExpandOpTy>(
+          collapseOp, resultType, expandOp.src(), composedReassociation);
+    return success();
+  }
+};
+
+template <typename ExpandOpTy, typename CollapseOpTy>
+struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
+  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
+    if (!collapseOp)
+      return failure();
+
+    ShapedType srcType = collapseOp.getSrcType();
+    ShapedType resultType = expandOp.getResultType();
+
+    if (hasNonIdentityLayout(expandOp.src().getType()) ||
+        hasNonIdentityLayout(collapseOp.src().getType()) ||
+        hasNonIdentityLayout(collapseOp.result().getType()))
+      return failure();
+
+    int64_t srcRank = srcType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if (srcType == resultType)
+      return failure();
+
+    auto srcReassociation = collapseOp.getReassociationIndices();
+    auto resultReassociation = expandOp.getReassociationIndices();
+    if (srcRank > resultRank) {
+      auto composedReassociation = findCollapsingReassociation(
+          srcReassociation, resultReassociation, srcType.getShape(),
+          resultType.getShape());
+      if (!composedReassociation.hasValue())
+        return failure();
+
+      rewriter.replaceOpWithNewOp<CollapseOpTy>(
+          expandOp, resultType, collapseOp.src(), *composedReassociation);
+      return success();
+    }
+    auto composedReassociation =
+        findCollapsingReassociation(resultReassociation, srcReassociation,
+                                    resultType.getShape(), srcType.getShape());
+    if (!composedReassociation.hasValue())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExpandOpTy>(
+        expandOp, resultType, collapseOp.src(), *composedReassociation);
     return success();
   }
+
+private:
+  // Attempts to find a way to collapse `srcShape` to `resultShape` by
+  // collapsing subshapes defined by the reassociation indices.
+  Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
+      ArrayRef<ReassociationIndices> srcReassociation,
+      ArrayRef<ReassociationIndices> resultReassociation,
+      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
+    SmallVector<ReassociationIndices, 4> composedReassociation;
+
+    if (srcReassociation.empty())
+      return {getReassociationIndicesForCollapse(srcShape, resultShape)};
+
+    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
+      auto &srcIndices = std::get<0>(item);
+      auto &resultIndices = std::get<1>(item);
+      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
+      auto resultSubShape =
+          resultShape.slice(resultIndices.front(), resultIndices.size());
+
+      if (srcSubShape.size() == resultSubShape.size()) {
+        if (srcSubShape != resultSubShape)
+          return llvm::None;
+        composedReassociation.push_back(srcIndices);
+        continue;
+      }
+
+      // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
+      auto subShapeReassociation =
+          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
+      if (!subShapeReassociation.hasValue())
+        return llvm::None;
+
+      // Remap the subshape indices back to the original srcShape.
+      for (auto &subShapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subShapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
+      }
+    }
+    return {std::move(composedReassociation)};
+  }
 };
 
 } // namespace mlir
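
The composition loop in ComposeCollapseOfExpandOp is compact but subtle: both
reassociations partition the dimensions of the intermediate type, and the loop
groups the sets of the finer partition so that each group covers exactly one
set of the coarser partition. Below is a minimal standalone sketch of that loop
(plain C++17, no MLIR dependencies; `compose` and `Reassociation` are names
made up for the illustration), traced on the example from the doc comment:
composing the expand_shape reassociation [[0], [1], [2, 3]] with the
collapse_shape reassociation [[0, 1], [2, 3]] yields [[0, 1], [2]]. Each
composed index is the ID of a set in the higher-rank reassociation, and those
sets correspond one-to-one to the dimensions on the higher-rank side, so the
composed groups are exactly the reassociation of the resulting single reshape.

    #include <cstdint>
    #include <cstdio>
    #include <optional>
    #include <vector>

    using Reassociation = std::vector<std::vector<int64_t>>;

    // Groups the sets of `higher` so that each group covers exactly one set
    // of `lower`; returns std::nullopt when a set of `higher` straddles a
    // boundary of `lower`.
    std::optional<Reassociation> compose(const Reassociation &higher,
                                         const Reassociation &lower) {
      Reassociation composed;
      size_t higherID = 0;
      for (const auto &lowerIndices : lower) {
        std::vector<int64_t> group;
        while (higherID < higher.size()) {
          int64_t rightmost = higher[higherID].back();
          if (rightmost > lowerIndices.back())
            return std::nullopt; // No union of sets covers this one.
          group.push_back(higherID++);
          if (rightmost == lowerIndices.back())
            break;
        }
        composed.push_back(group);
      }
      return composed;
    }

    int main() {
      Reassociation expand = {{0}, {1}, {2, 3}};  // reassociation_1
      Reassociation collapse = {{0, 1}, {2, 3}};  // reassociation_2
      auto result = compose(expand, collapse);
      if (!result)
        return 1;
      for (const auto &group : *result) { // Prints "{0,1} {2}".
        std::printf("{");
        for (size_t i = 0; i < group.size(); ++i)
          std::printf(i ? ",%lld" : "%lld", (long long)group[i]);
        std::printf("} ");
      }
      std::printf("\n");
    }

If a set of the finer partition straddles a boundary of the coarser one, e.g.
[[0], [1, 2], [3]] against [[0, 1], [2, 3]], `compose` returns std::nullopt,
which corresponds to the pattern bailing out with failure().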

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 094e6611cd101..86dc0199b0c39 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1802,8 +1802,9 @@ LogicalResult ExpandShapeOp::verify() {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
+      context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -1966,8 +1967,8 @@ struct CollapseShapeOpMemRefCastFolder
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1c8065ec88095..5b52a3fdd24ce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
               FoldReshapeWithConstant<CollapseShapeOp>,
               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }
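
For reference, a hedged sketch of how these registrations are typically
exercised: the -canonicalize pass (or any greedy rewrite driver) collects each
op's canonicalization patterns and applies them to fixpoint. The helper below
is illustrative only, is not part of the patch, and assumes the MLIR APIs of
this revision.

    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Illustrative helper (not part of the patch): applies the tensor reshape
    // canonicalizations, including the Compose* patterns registered above,
    // to a module until no more of them apply.
    static mlir::LogicalResult composeTensorReshapes(mlir::ModuleOp module) {
      mlir::MLIRContext *ctx = module.getContext();
      mlir::RewritePatternSet patterns(ctx);
      mlir::tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx);
      mlir::tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx);
      return mlir::applyPatternsAndFoldGreedily(module.getOperation(),
                                                std::move(patterns));
    }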

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 03cd3af2e7bec..64937be9fac05 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,18 +18,23 @@ using namespace mlir;
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
-  // Make the sourceType greater rank than the targetType. If they are same
-  // rank, then its an unsupported reshape op.
-  if (sourceType.getRank() == targetType.getRank())
-    return llvm::None;
+  if (sourceType.getRank() > targetType.getRank())
+    return getReassociationIndicesForCollapse(sourceType.getShape(),
+                                              targetType.getShape());
   if (sourceType.getRank() < targetType.getRank())
-    std::swap(sourceType, targetType);
+    return getReassociationIndicesForCollapse(targetType.getShape(),
+                                              sourceType.getShape());
+  return llvm::None;
+}
 
-  ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  ArrayRef<int64_t> targetShape = targetType.getShape();
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
+  if (sourceShape.size() <= targetShape.size())
+    return llvm::None;
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetType.getRank());
+  reassociationMap.reserve(targetShape.size());
 
   ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
@@ -37,7 +42,7 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
     unsigned targetDim = reassociationMap.size();
     // If we have mapped all the target dimensions, stop and handle the
     // remaining tail of size-1 dimensions explicitly.
-    if (targetDim == targetType.getRank())
+    if (targetDim == targetShape.size())
       break;
 
     int64_t currTargetShape = targetShape[targetDim];
@@ -187,6 +192,7 @@ mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
   }
   return maps;
 }
+
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                 int *invalidIndex) {
   if (reassociation.empty())
@@ -258,3 +264,9 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
   }
   return success();
 }
+
+bool mlir::hasNonIdentityLayout(Type type) {
+  if (auto memrefType = type.dyn_cast<MemRefType>())
+    return !memrefType.getLayout().isIdentity();
+  return false;
+}
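
The greedy walk in getReassociationIndicesForCollapse accumulates source
dimensions until their product matches the current target dimension, then
folds the remaining tail of size-1 dimensions into the last group. Below is a
simplified standalone sketch of that idea, restricted to static shapes (the
real function also reasons about dynamic dimensions and may choose a
different, equally valid grouping; the names here are made up for the
illustration).

    #include <cstdint>
    #include <cstdio>
    #include <optional>
    #include <vector>

    using Reassociation = std::vector<std::vector<int64_t>>;

    std::optional<Reassociation>
    reassociationForCollapse(const std::vector<int64_t> &source,
                             const std::vector<int64_t> &target) {
      if (source.size() <= target.size())
        return std::nullopt;
      Reassociation map;
      size_t sourceDim = 0;
      for (int64_t targetSize : target) {
        std::vector<int64_t> group;
        int64_t prod = 1;
        // Take source dims until their product reaches the target dim.
        while (sourceDim < source.size() &&
               (prod < targetSize || group.empty())) {
          prod *= source[sourceDim];
          group.push_back(sourceDim++);
        }
        if (prod != targetSize)
          return std::nullopt; // Products do not line up.
        map.push_back(group);
      }
      // Fold the remaining tail of size-1 dims into the last group.
      while (sourceDim < source.size()) {
        if (source[sourceDim] != 1)
          return std::nullopt;
        map.back().push_back(sourceDim++);
      }
      return map;
    }

    int main() {
      // Collapsing 1x4x1x512 to 4x512, as in the canonicalization tests.
      auto map = reassociationForCollapse({1, 4, 1, 512}, {4, 512});
      if (!map)
        return 1;
      for (const auto &group : *map) { // Prints "{0,1} {2,3}".
        std::printf("{");
        for (size_t i = 0; i < group.size(); ++i)
          std::printf(i ? ",%lld" : "%lld", (long long)group[i]);
        std::printf("} ");
      }
      std::printf("\n");
    }

On {1, 4, 1, 512} -> {4, 512} this prints {0,1} {2,3}, one valid grouping of
the unit dimensions; ComposeExpandOfCollapseOp only needs some such grouping
to exist in order to remap it back onto the original shape.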

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 8a4f80e77b61f..1a01460a24dc9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index)  {
 
 // -----
 
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-                                             -> memref<f32> {
+func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
+    -> memref<f32> {
   %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
       : memref<1x1x1xf32> into memref<1xf32>
   %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
 //       CHECK:   memref.collapse_shape %{{.*}} []
 //  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
 
 // -----
 
-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
     -> memref<?x?xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,13 +323,30 @@ func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
       : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
 //       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   memref.collapse_shape
 
 // -----
 
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+func @do_not_compose_collapse_of_expand_non_identity_layout(
+    %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
+    -> memref<?xf32> {
+  %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
+    memref<?x?xf32, offset : 0, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+  %2 = memref.collapse_shape %1 [[0, 1, 2]] :
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
+    memref<?xf32>
+  return %2 : memref<?xf32>
+}
+// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
+// CHECK: expand
+// CHECK: collapse
+
+// -----
+
+func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
     -> memref<?x6x4x5x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
@@ -337,45 +354,46 @@ func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
       : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
   return %1 : memref<?x6x4x5x?xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes
+// CHECK-LABEL: func @compose_expand_of_expand
 //       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   memref.expand_shape
 
 // -----
 
-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-                                             -> memref<1x1x1xf32> {
+func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
+    -> memref<1x1x1xf32> {
   %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
   %1 = memref.expand_shape %0 [[0, 1, 2]]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
 //       CHECK:   memref.expand_shape %{{.*}} []
 //  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
 
 // -----
 
-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape
+// CHECK-LABEL: func @fold_collapse_of_expand
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
 
-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : memref<?x?xf32>)
+    -> memref<?x?xf32> {
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 22770c2e67342..493f278763bb7 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
 
 // -----
 
-func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
+func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
     -> tensor<?x6x4x?x5xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,49 +654,51 @@ func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
       : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes
+// CHECK-LABEL: compose_expand_of_expand
 //       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   tensor.expand_shape
 
 // -----
 
-func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
   %1 = tensor.expand_shape %0 [[0, 1, 2]]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
-// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
 //       CHECK:   tensor.expand_shape %{{.*}} []
 //  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
 
 // -----
 
-func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape
+// CHECK-LABEL: @fold_collapse_of_expand
 //   CHECK-NOT:   linalg.{{.*}}shape
 
 // -----
 
-func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @fold_tensor_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
-func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+
+func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -704,7 +706,7 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
-//      CHECK: func @reshape_collapse
+//      CHECK: func @compose_expand_of_collapse
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
 //      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
@@ -712,7 +714,7 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
 
 // -----
 
-func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
+func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -720,7 +722,7 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
-//      CHECK: func @reshape_expand
+//      CHECK: func @compose_expand_of_collapse_7D
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
 //      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
@@ -728,20 +730,69 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
 
 // -----
 
-func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
+func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
+    -> tensor<?x?xi64> {
+  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+    : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+    : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+  return %1 : tensor<?x?xi64>
+}
+// CHECK-LABEL: func @compose_collapse_of_expand
+//       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>)
+//  CHECK-NEXT: tensor.collapse_shape %[[ARG]]
+//  CHECK-SAME:   [0, 1], [2]
+//  CHECK-SAME:   : tensor<?x?x?xi64> into tensor<?x?xi64>
+
+// -----
+
+func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
+    -> tensor<4x512xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
     : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
-//       CHECK: func @expand_reshape_1D
+//       CHECK: func @compose_collapse_of_expand_1D
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
 
-// CHECK-LABEL: zero_rank_reshape_multi
+func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
+    -> tensor<1x1x1x1xf32> {
+  %0 = tensor.collapse_shape %arg0 []
+      : tensor<1x1x1xf32> into tensor<f32>
+  %1 = tensor.expand_shape %0 []
+      : tensor<f32> into tensor<1x1x1x1xf32>
+  return %1 : tensor<1x1x1x1xf32>
+}
+//      CHECK: func @compose_expand_of_collapse_0_rank_to_expand
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1xf32>
+//      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2, 3]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
+    -> tensor<1x1x1xf32> {
+  %0 = tensor.collapse_shape %arg0 []
+      : tensor<1x1x1x1xf32> into tensor<f32>
+  %1 = tensor.expand_shape %0 []
+      : tensor<f32> into tensor<1x1x1xf32>
+  return %1 : tensor<1x1x1xf32>
+}
+//      CHECK: func @compose_expand_of_collapse_0_rank_to_collapse
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1x1xf32>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2, 3]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
   %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -752,7 +803,7 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
     -> tensor<?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -760,39 +811,39 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
 //       CHECK:   tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
 //   CHECK-NOT:   tensor.collapse_shape
 
 // -----
 
-func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
     -> tensor<f32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
       : tensor<1x1x1xf32> into tensor<1xf32>
   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
-// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
 //       CHECK:   tensor.collapse_shape %{{.*}} []
 //  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
 
 // -----
 
-func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
     : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
-//       CHECK: func @fold_reshape_1D
+//       CHECK: func @fold_collapse_of_expand_1D
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>
 
 // -----
 
-func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
     : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -800,13 +851,13 @@ func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
-//       CHECK: func @fold_reshape_unit_dims
+//       CHECK: func @fold_collapse_of_expand_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
 //  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
-func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
     : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -814,69 +865,70 @@ func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
-//       CHECK: func @expand_reshape_unit_dims
+//       CHECK: func @compose_collapse_of_expand_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
 //  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
+func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-//       CHECK: func @fold_reshape_trailing_unit_dims
+//       CHECK: func @compose_collapse_of_expand_trailing_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
-    -> tensor<?x?x?x?xf32> {
+func @compose_collapse_of_collapse_unit_dims_dynamic(
+    %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
     : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
     : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
-//       CHECK: func @collapse_reshape_unit_dims_dynamic
+//       CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
 //       CHECK: tensor.collapse_shape
 //  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
 //  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
-{
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+    -> tensor<2x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
-//       CHECK: func @fold_reshape_trailing_unit_dims
+//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
-    -> tensor<?xf32> {
+func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
+    %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
       : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
       : tensor<?x1x1x1xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
-//       CHECK: func @fold_reshape_trailing_unit_dims_dynamic
+//       CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
 //  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
 
 // -----
 
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
       : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -884,27 +936,28 @@ func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
 }
-//       CHECK: func @fold_reshape_trailing_unit_dims
+//       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
 //  CHECK-SAME:   tensor<12x42x1x1xf32> into tensor<12x42xf32>
 
 // -----
 
-func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+    -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
 //       CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
 //  CHECK-SAME:   tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
 
-func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -912,20 +965,21 @@ func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
-// CHECK-LABEL: func @no_fold_reshape_incompatible
+// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
 //       CHECK:   tensor.expand_shape
 //       CHECK:   tensor.collapse_shape
 
 // -----
 
-func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
+    -> tensor<12x1xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
-//      CHECK: func @no_fold_reshape_empty_expr
+//      CHECK: func @no_fold_collapse_of_expand_empty_expr
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
 //      CHECK:    %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:      [0], [1], [2, 3]
@@ -1002,11 +1056,11 @@ func @fold_rank() -> (index) {
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-LABEL: func @pad_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //   CHECK-NOT:   tensor.pad
 //       CHECK:   return %[[ARG0]]
-func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1018,11 +1072,11 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_nofold_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //       CHECK:   %[[PAD:.*]] = tensor.pad
 //       CHECK:   return %[[PAD]]
-func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1034,7 +1088,7 @@ func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
-// CHECK-LABEL:   func @pad_tensor_after_cast_different_shape(
+// CHECK-LABEL:   func @pad_after_cast_different_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1046,7 +1100,7 @@ func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK:         }
-func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
     -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1059,7 +1113,7 @@ func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
 
 // -----
 
-// CHECK-LABEL:   func @pad_tensor_after_cast_same_shape(
+// CHECK-LABEL:   func @pad_after_cast_same_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
 // CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1070,7 +1124,7 @@ func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:           return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
 // CHECK:         }
-func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
     -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1083,11 +1137,11 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-LABEL: func @pad_of_cast(
 // CHECK-NOT:     tensor.cast
 // CHECK:         tensor.pad
 // CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1133,7 +1187,7 @@ func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> ten
 
 // -----
 
-func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1143,17 +1197,17 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
   } : tensor<?x?xf32> to tensor<4x4xf32>
   return %1 : tensor<4x4xf32>
 }
-// CHECK-LABEL: @tensor_pad_cast
+// CHECK-LABEL: @pad_cast
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK: return %[[ARG0]]
 
 // -----
 
-// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+// CHECK-LABEL: func @fold_pad_source_cast(
 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
 //   CHECK-NOT:   tensor.cast
 //       CHECK:   %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
   %cst = arith.constant 0.0 : f32
   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[0, 1]  {

