[Mlir-commits] [mlir] 96e9b6c - Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."
Hanhan Wang
llvmlistbot at llvm.org
Tue Apr 5 15:06:05 PDT 2022
Author: Hanhan Wang
Date: 2022-04-05T15:05:41-07:00
New Revision: 96e9b6c9dc60946f08399def879a19395bc98107
URL: https://github.com/llvm/llvm-project/commit/96e9b6c9dc60946f08399def879a19395bc98107
DIFF: https://github.com/llvm/llvm-project/commit/96e9b6c9dc60946f08399def879a19395bc98107.diff
LOG: Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."
This reverts commit 64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7.
With that commit applied, canonicalization generates an invalid tensor.expand_shape op. To reproduce:
$ mlir-opt -canonicalize a.mlir
```
func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
%5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
%6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
%7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
%8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
%9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
%10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
%11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
%12 = linalg.init_tensor [8] : tensor<8xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<8xf32>
%14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
%15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
%16 = linalg.init_tensor [] : tensor<f32>
%17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<f32>
%18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
%19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
%20 = linalg.init_tensor [8] : tensor<8xf32>
%21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<8xf32>
%22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
%23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
%24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
%25 = linalg.init_tensor [8] : tensor<8xf32>
%26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<8xf32>
%27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
%28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
%29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
return %29 : tensor<1x1xf32>
}
```
Differential Revision: https://reviews.llvm.org/D123161
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e2b4c0742ffdf..dfeac25fd6c99 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,12 +68,6 @@ SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
-/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
-/// possible.
-Optional<SmallVector<ReassociationIndices>>
-getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape);
-
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
@@ -162,13 +156,10 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
op.getReassociationIndices(), isExpandingReshape);
}
-/// Returns true iff the type is a MemRefType and has a non-identity layout.
-bool hasNonIdentityLayout(Type type);
-
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy>
-struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
PatternRewriter &rewriter) const override {
@@ -177,12 +168,6 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
return failure();
ShapedType resultType = reshapeOp.getResultType();
-
- if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
- hasNonIdentityLayout(reshapeOp.src().getType()) ||
- hasNonIdentityLayout(reshapeOp.result().getType()))
- return failure();
-
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
reshapeOp.getReassociationIndices(),
@@ -195,180 +180,46 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
}
};
-/// Pattern to compose
-/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
-/// In that case both `srcType` and `resultType` can be expressed as a function
-/// of `intermediateType`.
-/// In order to demonstrate the approach, let's assume that `rank(srcType) >
-/// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
-/// In that case, we can iterate over every set of indices in `reassociation_2`
-/// and try to find ids of sets of indices in `reassociation_1` that cover it
-/// completely.
-///
-/// Example:
-///
-/// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
-/// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
-/// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
-/// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
-///
-/// can be canonicalized into
-///
-/// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
-/// : tensor<?x?x?xi64> into tensor<?x?xi64>
-///
-/// because [0] and [1] from `expand_shape` reassociation cover completely
-/// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
-/// indices, then we fail.
-//
-/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
-/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
-struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
- using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+ using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
PatternRewriter &rewriter) const override {
- auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
- if (!expandOp)
+ auto srcReshapeOp =
+ reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+ if (!srcReshapeOp)
return failure();
- ShapedType srcType = expandOp.getSrcType();
- ShapedType resultType = collapseOp.getResultType();
-
- if (hasNonIdentityLayout(collapseOp.src().getType()) ||
- hasNonIdentityLayout(expandOp.src().getType()) ||
- hasNonIdentityLayout(expandOp.result().getType()))
- return failure();
+ ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+ ShapedType intermediateType = reshapeOp.getSrcType();
+ ShapedType resultType = reshapeOp.getResultType();
- int64_t srcRank = srcType.getRank();
- int64_t resultRank = resultType.getRank();
- if (srcType == resultType)
+ // If the source reshape can be collapsed/expanded into the target reshape
+ // they can still be folded. This can only be reasoned about statically
+ // for cases where
+ // - either all shapes are static, or
+ // - The number of dynamic dimensions matches in the source of source and
+ // result with all other dimensions being 1.
+ Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+ getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+ if (!reassociationIndices)
return failure();
-
- SmallVector<ReassociationIndices, 4> higherRankReassociation,
- lowerRankReassociation;
-
- bool isResultCollapsed = srcRank > resultRank;
- if (isResultCollapsed) {
- higherRankReassociation = expandOp.getReassociationIndices();
- lowerRankReassociation = collapseOp.getReassociationIndices();
- } else {
- higherRankReassociation = collapseOp.getReassociationIndices();
- lowerRankReassociation = expandOp.getReassociationIndices();
- }
-
- size_t higherRankIndicesID = 0;
- SmallVector<ReassociationIndices, 4> composedReassociation;
- for (const auto &lowerRankIndices : lowerRankReassociation) {
- ReassociationIndices composedIndices;
- while (higherRankIndicesID < higherRankReassociation.size()) {
- auto rightmostIndex =
- higherRankReassociation[higherRankIndicesID].back();
- if (rightmostIndex > lowerRankIndices.back())
- return failure();
- composedIndices.push_back(higherRankIndicesID++);
- if (rightmostIndex == lowerRankIndices.back())
- break;
- }
- composedReassociation.push_back(composedIndices);
- }
- if (isResultCollapsed)
- rewriter.replaceOpWithNewOp<CollapseOpTy>(
- collapseOp, resultType, expandOp.src(), composedReassociation);
+ bool originalOpExpands =
+ intermediateType.getRank() > srcReshapeSrcType.getRank();
+ bool resultingOpExpands =
+ resultType.getRank() > srcReshapeSrcType.getRank();
+ if (!(resultingOpExpands ^ originalOpExpands))
+ rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
else
- rewriter.replaceOpWithNewOp<ExpandOpTy>(
- collapseOp, resultType, expandOp.src(), composedReassociation);
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
return success();
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
-struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
- using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(ExpandOpTy expandOp,
- PatternRewriter &rewriter) const override {
- auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
- if (!collapseOp)
- return failure();
-
- ShapedType srcType = collapseOp.getSrcType();
- ShapedType resultType = expandOp.getResultType();
-
- if (hasNonIdentityLayout(expandOp.src().getType()) ||
- hasNonIdentityLayout(collapseOp.src().getType()) ||
- hasNonIdentityLayout(collapseOp.result().getType()))
- return failure();
-
- int64_t srcRank = srcType.getRank();
- int64_t resultRank = resultType.getRank();
- if (srcType == resultType)
- return failure();
-
- auto srcReassociation = collapseOp.getReassociationIndices();
- auto resultReassociation = expandOp.getReassociationIndices();
- if (srcRank > resultRank) {
- auto composedReassociation = findCollapsingReassociation(
- srcReassociation, resultReassociation, srcType.getShape(),
- resultType.getShape());
- if (!composedReassociation.hasValue())
- return failure();
-
- rewriter.replaceOpWithNewOp<CollapseOpTy>(
- expandOp, resultType, collapseOp.src(), *composedReassociation);
- return success();
- }
- auto composedReassociation =
- findCollapsingReassociation(resultReassociation, srcReassociation,
- resultType.getShape(), srcType.getShape());
- if (!composedReassociation.hasValue())
- return failure();
-
- rewriter.replaceOpWithNewOp<ExpandOpTy>(
- expandOp, resultType, collapseOp.src(), *composedReassociation);
- return success();
- }
-
-private:
- // Attempts to find a way to collapse `srcShape` to `resultShape` by
- // collapsing subshapes defined by the reassociation indices.
- Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
- ArrayRef<ReassociationIndices> srcReassociation,
- ArrayRef<ReassociationIndices> resultReassociation,
- ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
- SmallVector<ReassociationIndices, 4> composedReassociation;
-
- for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
- auto &srcIndices = std::get<0>(item);
- auto &resultIndices = std::get<1>(item);
- auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
- auto resultSubShape =
- resultShape.slice(resultIndices.front(), resultIndices.size());
-
- if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape)
- composedReassociation.push_back(srcIndices);
- else
- return llvm::None;
- }
-
- // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
- auto subShapeReassociation =
- getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
- if (!subShapeReassociation.hasValue())
- return llvm::None;
-
- // Remap the subshape indices back to the original srcShape.
- for (auto &subshape_indices : *subShapeReassociation) {
- ReassociationIndices shape_indices;
- for (int64_t index : subshape_indices)
- shape_indices.push_back(srcIndices.front() + index);
- composedReassociation.push_back(shape_indices);
- }
- }
- return {std::move(composedReassociation)};
- }
-};
-
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bb36f3d00d179..5a8bd2b8dd551 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1793,9 +1793,8 @@ LogicalResult ExpandShapeOp::verify() {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
- context);
+ results.add<CollapseReshapeOps<ExpandShapeOp>,
+ CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}
/// Compute the layout map after collapsing a given source MemRef type with the
@@ -2000,8 +1999,8 @@ struct CollapseShapeOpMemRefCastFolder
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ results.add<CollapseReshapeOps<CollapseShapeOp>,
+ CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
CollapseShapeOpMemRefCastFolder>(context);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5b52a3fdd24ce..1c8065ec88095 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ results.add<CollapseReshapeOps<ExpandShapeOp>,
+ CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+ results.add<CollapseReshapeOps<CollapseShapeOp>,
+ CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>>(context);
}
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 64937be9fac05..03cd3af2e7bec 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,23 +18,18 @@ using namespace mlir;
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
- if (sourceType.getRank() > targetType.getRank())
- return getReassociationIndicesForCollapse(sourceType.getShape(),
- targetType.getShape());
+ // Make the sourceType greater rank than the targetType. If they are same
+ // rank, then its an unsupported reshape op.
+ if (sourceType.getRank() == targetType.getRank())
+ return llvm::None;
if (sourceType.getRank() < targetType.getRank())
- return getReassociationIndicesForCollapse(targetType.getShape(),
- sourceType.getShape());
- return llvm::None;
-}
+ std::swap(sourceType, targetType);
-Optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
- return llvm::None;
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> targetShape = targetType.getShape();
unsigned sourceDim = 0;
SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
+ reassociationMap.reserve(targetType.getRank());
ReassociationIndices currIndices;
int64_t prodOfCollapsedDims = 1;
@@ -42,7 +37,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
unsigned targetDim = reassociationMap.size();
// If we have mapped all the target dimensions stop and handle the remaining
// tail of size-1 dimensions explictly.
- if (targetDim == targetShape.size())
+ if (targetDim == targetType.getRank())
break;
int64_t currTargetShape = targetShape[targetDim];
@@ -192,7 +187,6 @@ mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
}
return maps;
}
-
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex) {
if (reassociation.empty())
@@ -264,9 +258,3 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
}
return success();
}
-
-bool mlir::hasNonIdentityLayout(Type type) {
- if (auto memrefType = type.dyn_cast<MemRefType>())
- return !memrefType.getLayout().isIdentity();
- return false;
-}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 1a01460a24dc9..8a4f80e77b61f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
// -----
-func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
- -> memref<f32> {
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+ -> memref<f32> {
%0 = memref.collapse_shape %arg0 [[0, 1, 2]]
: memref<1x1x1xf32> into memref<1xf32>
%1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
return %1 : memref<f32>
}
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
// CHECK: memref.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
// -----
-func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
-> memref<?x?xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,30 +323,13 @@ func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
: memref<?x?x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_memref_reshapes
// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.collapse_shape
// -----
-func @do_not_compose_collapse_of_expand_non_identity_layout(
- %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
- -> memref<?xf32> {
- %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
- memref<?x?xf32, offset : 0, strides : [?, 1]> into
- memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
- %2 = memref.collapse_shape %1 [[0, 1, 2]] :
- memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
- memref<?xf32>
- return %2 : memref<?xf32>
-}
-// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
-// CHECK: expand
-// CHECK: collapse
-
-// -----
-
-func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
-> memref<?x6x4x5x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
@@ -354,46 +337,45 @@ func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
: memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
return %1 : memref<?x6x4x5x?xf32>
}
-// CHECK-LABEL: func @compose_expand_of_expand
+// CHECK-LABEL: expanding_memref_reshapes
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.expand_shape
// -----
-func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
- -> memref<1x1x1xf32> {
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+ -> memref<1x1x1xf32> {
%0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
%1 = memref.expand_shape %0 [[0, 1, 2]]
: memref<1xf32> into memref<1x1x1xf32>
return %1 : memref<1x1x1xf32>
}
-// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
// CHECK: memref.expand_shape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
// -----
-func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<12x4xf32> into memref<3x4x4xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<3x4x4xf32> into memref<12x4xf32>
return %1 : memref<12x4xf32>
}
-// CHECK-LABEL: func @fold_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape
// CHECK-NOT: linalg.{{.*}}_shape
// -----
-func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
- -> memref<?x?xf32> {
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<?x4x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
-// CHECK-LABEL: @fold_collapse_collapse_of_expand
+// CHECK-LABEL: @fold_memref_reshape_dynamic
// CHECK-NOT: linalg.{{.*}}_shape
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9996b9776c4d5..22770c2e67342 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
// -----
-func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
-> tensor<?x6x4x?x5xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,51 +654,49 @@ func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
: tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
return %1 : tensor<?x6x4x?x5xf32>
}
-// CHECK-LABEL: compose_expand_of_expand
+// CHECK-LABEL: expanding_tensor_reshapes
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.expand_shape
// -----
-func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
%0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = tensor.expand_shape %0 [[0, 1, 2]]
: tensor<1xf32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
-// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
// CHECK: tensor.expand_shape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
// -----
-func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<12x4xf32> into tensor<3x4x4xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<3x4x4xf32> into tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
-// CHECK-LABEL: @fold_collapse_of_expand
+// CHECK-LABEL: @fold_tensor_reshape
// CHECK-NOT: linalg.{{.*}}shape
// -----
-func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
- -> tensor<?x?xf32> {
+func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: @fold_collapse_of_expand_dynamic
+// CHECK-LABEL: @fold_tensor_reshape_dynamic
// CHECK-NOT: linalg.{{.*}}_shape
// -----
-
-func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
-> tensor<24x5x42x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
: tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -706,7 +704,7 @@ func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
: tensor<40320xf32> into tensor<24x5x42x8xf32>
return %1 : tensor<24x5x42x8xf32>
}
-// CHECK: func @compose_expand_of_collapse
+// CHECK: func @reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -714,7 +712,7 @@ func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
// -----
-func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
+func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
-> tensor<2x3x4x5x6x7x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
: tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -722,7 +720,7 @@ func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
: tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
return %1 : tensor<2x3x4x5x6x7x8xf32>
}
-// CHECK: func @compose_expand_of_collapse_7D
+// CHECK: func @reshape_expand
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -730,37 +728,20 @@ func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
// -----
-func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
- -> tensor<?x?xi64> {
- %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
- : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
- %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
- : tensor<?x?x?x1xi64> into tensor<?x?xi64>
- return %1 : tensor<?x?xi64>
-}
-// CHECK-LABEL: func @compose_collapse_of_expand
-// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>)
-// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
-// CHECK-SAME: [0, 1], [2]
-// CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
-
-// -----
-
-func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
- -> tensor<4x512xf32> {
+func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
: tensor<2048xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<1x4x1x512xf32> into tensor<4x512xf32>
return %1 : tensor<4x512xf32>
}
-// CHECK: func @compose_collapse_of_expand_1D
+// CHECK: func @expand_reshape_1D
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
// -----
-// CHECK-LABEL: func @zero_rank_reshape_multi
+// CHECK-LABEL: zero_rank_reshape_multi
func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
%0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -771,7 +752,7 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// -----
-func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
+func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
-> tensor<?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -779,39 +760,39 @@ func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
: tensor<?x?x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @compose_collapse_of_collapse
+// CHECK-LABEL: collapsing_tensor_reshapes
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.collapse_shape
// -----
-func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
-> tensor<f32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
: tensor<1x1x1xf32> into tensor<1xf32>
%1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
return %1 : tensor<f32>
}
-// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
// CHECK: tensor.collapse_shape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
// -----
-func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
: tensor<4x512xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
: tensor<1x4x1x512xf32> into tensor<2048xf32>
return %1 : tensor<2048xf32>
}
-// CHECK: func @fold_collapse_of_expand_1D
+// CHECK: func @fold_reshape_1D
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
// -----
-func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
-> tensor<4x512x1x1xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
: tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -819,13 +800,13 @@ func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
: tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
return %1 : tensor<4x512x1x1xf32>
}
-// CHECK: func @fold_collapse_of_expand_unit_dims
+// CHECK: func @fold_reshape_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
// -----
-func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
-> tensor<4x512x1x512x4xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
: tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -833,70 +814,69 @@ func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}
-// CHECK: func @compose_collapse_of_expand_unit_dims
+// CHECK: func @expand_reshape_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
// -----
-func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
- -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
-// CHECK: func @compose_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
-func @compose_collapse_of_collapse_unit_dims_dynamic(
- %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
+func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
+ -> tensor<?x?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
: tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
-// CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
+// CHECK: func @collapse_reshape_unit_dims_dynamic
// CHECK: tensor.collapse_shape
// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
// -----
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
- -> tensor<2x1xf32> {
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
%0 = tensor.expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
-// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
-func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
- %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
+func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
+ -> tensor<?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
: tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
: tensor<?x1x1x1xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
-// CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
+// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
// -----
-func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
-> tensor<12x42xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
: tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -904,28 +884,27 @@ func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
: tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
return %1 : tensor<12x42xf32>
}
-// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
+// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
// -----
-func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
- -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
: tensor<?x?x1x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
// CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
// -----
-func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
-> tensor<2x6x16xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
: tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -933,21 +912,20 @@ func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
: tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
return %1 : tensor<2x6x16xf32>
}
-// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
+// CHECK-LABEL: func @no_fold_reshape_incompatible
// CHECK: tensor.expand_shape
// CHECK: tensor.collapse_shape
// -----
-func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
- -> tensor<12x1xf32> {
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<3x2x2x1xf32> into tensor<12x1xf32>
return %1 : tensor<12x1xf32>
}
-// CHECK: func @no_fold_collapse_of_expand_empty_expr
+// CHECK: func @no_fold_reshape_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1], [2, 3]
@@ -1024,11 +1002,11 @@ func @fold_rank() -> (index) {
// -----
-// CHECK-LABEL: func @pad_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK: return %[[ARG0]]
-func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
-> tensor<5x6xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1040,11 +1018,11 @@ func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
-// CHECK-LABEL: func @pad_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK: %[[PAD:.*]] = tensor.pad
// CHECK: return %[[PAD]]
-func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
-> tensor<5x6xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1056,7 +1034,7 @@ func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
-// CHECK-LABEL: func @pad_after_cast_different_shape(
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1068,7 +1046,7 @@ func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1081,7 +1059,7 @@ func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
// -----
-// CHECK-LABEL: func @pad_after_cast_same_shape(
+// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1092,7 +1070,7 @@ func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
-> tensor<?x?x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1105,11 +1083,11 @@ func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
// -----
-// CHECK-LABEL: func @pad_of_cast(
+// CHECK-LABEL: func @pad_tensor_of_cast(
// CHECK-NOT: tensor.cast
// CHECK: tensor.pad
// CHECK: tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1155,7 +1133,7 @@ func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> ten
// -----
-func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1165,17 +1143,17 @@ func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
} : tensor<?x?xf32> to tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
-// CHECK-LABEL: @pad_cast
+// CHECK-LABEL: @tensor_pad_cast
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]
// -----
-// CHECK-LABEL: func @fold_pad_source_cast(
+// CHECK-LABEL: func @fold_pad_tensor_source_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
// CHECK-NOT: tensor.cast
// CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.0 : f32
%0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = tensor.pad %0 low[0, 0] high[0, 1] {
More information about the Mlir-commits mailing list