[Mlir-commits] [mlir] 64f659b - [mlir] Rewrite canonicalization of collapse(expand) and expand(collapse).
Author: Alexander Belyaev
Date: 2022-04-05T10:03:07+02:00
New Revision: 64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7
URL: https://github.com/llvm/llvm-project/commit/64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7
DIFF: https://github.com/llvm/llvm-project/commit/64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7.diff
LOG: [mlir] Rewrite canonicalization of collapse(expand) and expand(collapse).
Differential Revision: https://reviews.llvm.org/D122666
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index dfeac25fd6c99..e2b4c0742ffdf 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -68,6 +68,12 @@ SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
+/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
+/// possible.
+Optional<SmallVector<ReassociationIndices>>
+getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape);
+
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
@@ -156,10 +162,13 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
op.getReassociationIndices(), isExpandingReshape);
}
+/// Returns true iff the type is a MemRefType and has a non-identity layout.
+bool hasNonIdentityLayout(Type type);
+
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy>
-struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
PatternRewriter &rewriter) const override {
@@ -168,6 +177,12 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
return failure();
ShapedType resultType = reshapeOp.getResultType();
+
+ if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
+ hasNonIdentityLayout(reshapeOp.src().getType()) ||
+ hasNonIdentityLayout(reshapeOp.result().getType()))
+ return failure();
+
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
reshapeOp.getReassociationIndices(),
@@ -180,46 +195,180 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
}
};
-/// Pattern to collapse producer/consumer reshape ops that are both collapsing
-/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, typename InverseReshapeOpTy>
-struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
- using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+/// Pattern to compose
+/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
+/// Both `srcType` and `resultType` can then be expressed as a function of the
+/// `intermediateType`.
+/// To demonstrate the approach, assume that `rank(srcType) >
+/// rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
+/// In that case, we can iterate over every set of indices in `reassociation_2`
+/// and try to find the IDs of the sets of indices in `reassociation_1` that
+/// cover it completely.
+///
+/// Example:
+///
+/// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+/// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+/// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+/// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+///
+/// can be canonicalized into
+///
+/// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
+/// : tensor<?x?x?xi64> into tensor<?x?xi64>
+///
+/// because sets [0] and [1] from the `expand_shape` reassociation completely
+/// cover `[0, 1]` from the `collapse_shape` reassociation. If no such union of
+/// indices can be found, the pattern fails to apply.
+///
+/// When `rank(srcType) < rank(resultType)`, we simply swap `reassociation_1`
+/// and `reassociation_2` and produce an `expand_shape`.
+template <typename CollapseOpTy, typename ExpandOpTy>
+struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
+ using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
PatternRewriter &rewriter) const override {
- auto srcReshapeOp =
- reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
- if (!srcReshapeOp)
+ auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
+ if (!expandOp)
return failure();
- ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
- ShapedType intermediateType = reshapeOp.getSrcType();
- ShapedType resultType = reshapeOp.getResultType();
+ ShapedType srcType = expandOp.getSrcType();
+ ShapedType resultType = collapseOp.getResultType();
- // If the source reshape can be collapsed/expanded into the target reshape
- // they can still be folded. This can only be reasoned about statically
- // for cases where
- // - either all shapes are static, or
- // - The number of dynamic dimensions matches in the source of source and
- // result with all other dimensions being 1.
- Optional<SmallVector<ReassociationIndices>> reassociationIndices =
- getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
- if (!reassociationIndices)
+ if (hasNonIdentityLayout(collapseOp.src().getType()) ||
+ hasNonIdentityLayout(expandOp.src().getType()) ||
+ hasNonIdentityLayout(expandOp.result().getType()))
return failure();
- bool originalOpExpands =
- intermediateType.getRank() > srcReshapeSrcType.getRank();
- bool resultingOpExpands =
- resultType.getRank() > srcReshapeSrcType.getRank();
- if (!(resultingOpExpands ^ originalOpExpands))
- rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+
+ int64_t srcRank = srcType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if (srcType == resultType)
+ return failure();
+
+ SmallVector<ReassociationIndices, 4> higherRankReassociation,
+ lowerRankReassociation;
+
+ bool isResultCollapsed = srcRank > resultRank;
+ if (isResultCollapsed) {
+ higherRankReassociation = expandOp.getReassociationIndices();
+ lowerRankReassociation = collapseOp.getReassociationIndices();
+ } else {
+ higherRankReassociation = collapseOp.getReassociationIndices();
+ lowerRankReassociation = expandOp.getReassociationIndices();
+ }
+
+ size_t higherRankIndicesID = 0;
+ SmallVector<ReassociationIndices, 4> composedReassociation;
+ for (const auto &lowerRankIndices : lowerRankReassociation) {
+ ReassociationIndices composedIndices;
+ while (higherRankIndicesID < higherRankReassociation.size()) {
+ auto rightmostIndex =
+ higherRankReassociation[higherRankIndicesID].back();
+ if (rightmostIndex > lowerRankIndices.back())
+ return failure();
+ composedIndices.push_back(higherRankIndicesID++);
+ if (rightmostIndex == lowerRankIndices.back())
+ break;
+ }
+ composedReassociation.push_back(composedIndices);
+ }
+ if (isResultCollapsed)
+ rewriter.replaceOpWithNewOp<CollapseOpTy>(
+ collapseOp, resultType, expandOp.src(), composedReassociation);
else
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+ rewriter.replaceOpWithNewOp<ExpandOpTy>(
+ collapseOp, resultType, expandOp.src(), composedReassociation);
return success();
}
};
+template <typename ExpandOpTy, typename CollapseOpTy>
+struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
+ using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ExpandOpTy expandOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
+ if (!collapseOp)
+ return failure();
+
+ ShapedType srcType = collapseOp.getSrcType();
+ ShapedType resultType = expandOp.getResultType();
+
+ if (hasNonIdentityLayout(expandOp.src().getType()) ||
+ hasNonIdentityLayout(collapseOp.src().getType()) ||
+ hasNonIdentityLayout(collapseOp.result().getType()))
+ return failure();
+
+ int64_t srcRank = srcType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if (srcType == resultType)
+ return failure();
+
+ auto srcReassociation = collapseOp.getReassociationIndices();
+ auto resultReassociation = expandOp.getReassociationIndices();
+ if (srcRank > resultRank) {
+ auto composedReassociation = findCollapsingReassociation(
+ srcReassociation, resultReassociation, srcType.getShape(),
+ resultType.getShape());
+ if (!composedReassociation.hasValue())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<CollapseOpTy>(
+ expandOp, resultType, collapseOp.src(), *composedReassociation);
+ return success();
+ }
+ auto composedReassociation =
+ findCollapsingReassociation(resultReassociation, srcReassociation,
+ resultType.getShape(), srcType.getShape());
+ if (!composedReassociation.hasValue())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ExpandOpTy>(
+ expandOp, resultType, collapseOp.src(), *composedReassociation);
+ return success();
+ }
+
+private:
+ // Attempts to find a way to collapse `srcShape` to `resultShape` by
+ // collapsing subshapes defined by the reassociation indices.
+ Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
+ ArrayRef<ReassociationIndices> srcReassociation,
+ ArrayRef<ReassociationIndices> resultReassociation,
+ ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
+ SmallVector<ReassociationIndices, 4> composedReassociation;
+
+ for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
+ auto &srcIndices = std::get<0>(item);
+ auto &resultIndices = std::get<1>(item);
+ auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
+ auto resultSubShape =
+ resultShape.slice(resultIndices.front(), resultIndices.size());
+
+      if (srcSubShape.size() == resultSubShape.size()) {
+        if (srcSubShape == resultSubShape)
+          composedReassociation.push_back(srcIndices);
+        else
+          return llvm::None;
+        // Do not fall through: the collapse helper below rejects equal-rank
+        // shapes and would incorrectly fail the match.
+        continue;
+      }
+
+ // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
+ auto subShapeReassociation =
+ getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
+ if (!subShapeReassociation.hasValue())
+ return llvm::None;
+
+ // Remap the subshape indices back to the original srcShape.
+    for (auto &subShapeIndices : *subShapeReassociation) {
+      ReassociationIndices shapeIndices;
+      for (int64_t index : subShapeIndices)
+        shapeIndices.push_back(srcIndices.front() + index);
+      composedReassociation.push_back(shapeIndices);
+ }
+ }
+ return {std::move(composedReassociation)};
+ }
+};
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
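To make the two composition strategies above easy to experiment with outside of MLIR, here is a minimal standalone C++ sketch: composeSameKind mirrors the group concatenation performed by ComposeReassociativeReshapeOps (via composeReassociationIndices), and composeCovering mirrors the covering search in ComposeCollapseOfExpandOp. The names composeSameKind, composeCovering, and the Reassociation alias are illustrative stand-ins, not part of the MLIR API.

// Standalone illustration only; not the MLIR implementation.
#include <cstdint>
#include <optional>
#include <vector>

using Indices = std::vector<int64_t>;
using Reassociation = std::vector<Indices>;

// Same-kind composition (ComposeReassociativeReshapeOps): each composed group
// concatenates the `first`-level groups selected by a `second`-level group.
// For collapse-of-collapse, `first` is the producer reassociation; for
// expand-of-expand the roles swap.
Reassociation composeSameKind(const Reassociation &first,
                              const Reassociation &second) {
  Reassociation composed;
  for (const Indices &outer : second) {
    Indices group;
    for (int64_t j : outer)
      group.insert(group.end(), first[j].begin(), first[j].end());
    composed.push_back(group);
  }
  return composed;
}

// Mixed composition (ComposeCollapseOfExpandOp): for each lower-rank group,
// greedily take whole higher-rank groups until the rightmost indices line
// up; fail if a higher-rank group straddles a lower-rank group boundary.
std::optional<Reassociation> composeCovering(const Reassociation &higherRank,
                                             const Reassociation &lowerRank) {
  Reassociation composed;
  size_t higherId = 0;
  for (const Indices &lower : lowerRank) {
    Indices group;
    while (higherId < higherRank.size()) {
      int64_t rightmost = higherRank[higherId].back();
      if (rightmost > lower.back())
        return std::nullopt; // A higher-rank group straddles the boundary.
      group.push_back(static_cast<int64_t>(higherId++));
      if (rightmost == lower.back())
        break; // This lower-rank group is fully covered.
    }
    composed.push_back(group);
  }
  return composed;
}

int main() {
  // collapse(collapse): [[0, 1], [2], [3, 4]] then [[0, 1], [2]] composes to
  // [[0, 1, 2], [3, 4]], as in @compose_collapse_of_collapse below.
  bool okSame = composeSameKind({{0, 1}, {2}, {3, 4}}, {{0, 1}, {2}}) ==
                Reassociation({{0, 1, 2}, {3, 4}});
  // collapse(expand): expand [[0], [1], [2, 3]] then collapse
  // [[0, 1], [2, 3]] composes to [[0, 1], [2]], as in the doc comment above.
  std::optional<Reassociation> mixed =
      composeCovering({{0}, {1}, {2, 3}}, {{0, 1}, {2, 3}});
  bool okMixed = mixed.has_value() && *mixed == Reassociation({{0, 1}, {2}});
  return (okSame && okMixed) ? 0 : 1;
}

Both checks in main() reproduce examples from the doc comments and the tests below.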
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5a8bd2b8dd551..bb36f3d00d179 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1793,8 +1793,9 @@ LogicalResult ExpandShapeOp::verify() {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseReshapeOps<ExpandShapeOp>,
- CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+ results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
+ context);
}
/// Compute the layout map after collapsing a given source MemRef type with the
@@ -1999,8 +2000,8 @@ struct CollapseShapeOpMemRefCastFolder
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseReshapeOps<CollapseShapeOp>,
- CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+ results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
CollapseShapeOpMemRefCastFolder>(context);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1c8065ec88095..5b52a3fdd24ce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -890,16 +890,16 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseReshapeOps<ExpandShapeOp>,
- CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
+ results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseReshapeOps<CollapseShapeOp>,
- CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+ results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>>(context);
}
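With these registrations the new compositions fire as part of ordinary canonicalization; assuming the usual test setup, running `mlir-opt -canonicalize` over the IR in the test files below is enough to exercise each pattern.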
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 03cd3af2e7bec..64937be9fac05 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -18,18 +18,23 @@ using namespace mlir;
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
- // Make the sourceType greater rank than the targetType. If they are same
- // rank, then its an unsupported reshape op.
- if (sourceType.getRank() == targetType.getRank())
- return llvm::None;
+ if (sourceType.getRank() > targetType.getRank())
+ return getReassociationIndicesForCollapse(sourceType.getShape(),
+ targetType.getShape());
if (sourceType.getRank() < targetType.getRank())
- std::swap(sourceType, targetType);
+ return getReassociationIndicesForCollapse(targetType.getShape(),
+ sourceType.getShape());
+ return llvm::None;
+}
- ArrayRef<int64_t> sourceShape = sourceType.getShape();
- ArrayRef<int64_t> targetShape = targetType.getShape();
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ if (sourceShape.size() <= targetShape.size())
+ return llvm::None;
unsigned sourceDim = 0;
SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetType.getRank());
+ reassociationMap.reserve(targetShape.size());
ReassociationIndices currIndices;
int64_t prodOfCollapsedDims = 1;
@@ -37,7 +42,7 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
unsigned targetDim = reassociationMap.size();
// If we have mapped all the target dimensions, stop and handle the remaining
// tail of size-1 dimensions explicitly.
- if (targetDim == targetType.getRank())
+ if (targetDim == targetShape.size())
break;
int64_t currTargetShape = targetShape[targetDim];
@@ -187,6 +192,7 @@ mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
}
return maps;
}
+
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex) {
if (reassociation.empty())
@@ -258,3 +264,9 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
}
return success();
}
+
+bool mlir::hasNonIdentityLayout(Type type) {
+ if (auto memrefType = type.dyn_cast<MemRefType>())
+ return !memrefType.getLayout().isIdentity();
+ return false;
+}
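As a rough model of what getReassociationIndicesForCollapse computes, the following self-contained sketch greedily groups source dimensions whose product matches each target dimension, then folds a trailing run of size-1 dims into the last group. It assumes all extents are static and positive, unlike the real helper, which also handles dynamic dimensions; collapseReassociation is a hypothetical name used only for illustration.

// Standalone illustration only; simplified relative to the MLIR helper.
#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<std::vector<int64_t>>>
collapseReassociation(const std::vector<int64_t> &src,
                      const std::vector<int64_t> &dst) {
  if (src.size() <= dst.size())
    return std::nullopt;
  std::vector<std::vector<int64_t>> map;
  size_t s = 0;
  for (int64_t target : dst) {
    std::vector<int64_t> group;
    int64_t prod = 1;
    // Greedily take source dims while their product still fits the target.
    while (s < src.size() && prod * src[s] <= target) {
      prod *= src[s];
      group.push_back(static_cast<int64_t>(s++));
    }
    if (prod != target)
      return std::nullopt; // Source dims do not tile this target dim.
    map.push_back(std::move(group));
  }
  // Any remaining source dims must be a tail of size-1 dims; fold them into
  // the last group (assumes a non-empty target shape).
  for (; s < src.size(); ++s) {
    if (map.empty() || src[s] != 1)
      return std::nullopt;
    map.back().push_back(static_cast<int64_t>(s));
  }
  return map;
}

int main() {
  // 2x3x4x5x6x7x8 collapsed to 24x5x42x8 gives [[0, 1, 2], [3], [4, 5], [6]].
  return collapseReassociation({2, 3, 4, 5, 6, 7, 8}, {24, 5, 42, 8})
                 .has_value()
             ? 0
             : 1;
}

The example mirrors the @compose_expand_of_collapse test below, where tensor<2x3x4x5x6x7x8xf32> collapses to tensor<24x5x42x8xf32> with reassociation [[0, 1, 2], [3], [4, 5], [6]].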
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 8a4f80e77b61f..1a01460a24dc9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -302,20 +302,20 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
// -----
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
- -> memref<f32> {
+func @compose_collapse_of_collapse_zero_dim(%arg0 : memref<1x1x1xf32>)
+ -> memref<f32> {
%0 = memref.collapse_shape %arg0 [[0, 1, 2]]
: memref<1x1x1xf32> into memref<1xf32>
%1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
return %1 : memref<f32>
}
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
// CHECK: memref.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
// -----
-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
-> memref<?x?xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
@@ -323,13 +323,30 @@ func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
: memref<?x?x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
-// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.collapse_shape
// -----
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+func @do_not_compose_collapse_of_expand_non_identity_layout(
+ %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
+ -> memref<?xf32> {
+ %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
+ memref<?x?xf32, offset : 0, strides : [?, 1]> into
+ memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+ %2 = memref.collapse_shape %1 [[0, 1, 2]] :
+ memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
+ memref<?xf32>
+ return %2 : memref<?xf32>
+}
+// CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
+// CHECK: expand
+// CHECK: collapse
+
+// -----
+
+func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
-> memref<?x6x4x5x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
@@ -337,45 +354,46 @@ func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
: memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
return %1 : memref<?x6x4x5x?xf32>
}
-// CHECK-LABEL: expanding_memref_reshapes
+// CHECK-LABEL: func @compose_expand_of_expand
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.expand_shape
// -----
-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
- -> memref<1x1x1xf32> {
+func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
+ -> memref<1x1x1xf32> {
%0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
%1 = memref.expand_shape %0 [[0, 1, 2]]
: memref<1xf32> into memref<1x1x1xf32>
return %1 : memref<1x1x1xf32>
}
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
// CHECK: memref.expand_shape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
// -----
-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<12x4xf32> into memref<3x4x4xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<3x4x4xf32> into memref<12x4xf32>
return %1 : memref<12x4xf32>
}
-// CHECK-LABEL: @fold_memref_reshape
+// CHECK-LABEL: func @fold_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}_shape
// -----
-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : memref<?x?xf32>)
+ -> memref<?x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<?x4x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
-// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
// CHECK-NOT: linalg.{{.*}}_shape
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 22770c2e67342..9996b9776c4d5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -646,7 +646,7 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
// -----
-func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
+func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
-> tensor<?x6x4x?x5xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
@@ -654,49 +654,51 @@ func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
: tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
return %1 : tensor<?x6x4x?x5xf32>
}
-// CHECK-LABEL: expanding_tensor_reshapes
+// CHECK-LABEL: compose_expand_of_expand
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.expand_shape
// -----
-func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
%0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = tensor.expand_shape %0 [[0, 1, 2]]
: tensor<1xf32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
-// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
// CHECK: tensor.expand_shape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
// -----
-func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
+func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<12x4xf32> into tensor<3x4x4xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<3x4x4xf32> into tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
-// CHECK-LABEL: @fold_tensor_reshape
+// CHECK-LABEL: @fold_collapse_of_expand
// CHECK-NOT: linalg.{{.*}}shape
// -----
-func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: @fold_tensor_reshape_dynamic
+// CHECK-LABEL: @fold_collapse_of_expand_dynamic
// CHECK-NOT: linalg.{{.*}}_shape
// -----
-func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
+
+func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
-> tensor<24x5x42x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
: tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
@@ -704,7 +706,7 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
: tensor<40320xf32> into tensor<24x5x42x8xf32>
return %1 : tensor<24x5x42x8xf32>
}
-// CHECK: func @reshape_collapse
+// CHECK: func @compose_expand_of_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -712,7 +714,7 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
// -----
-func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
+func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
-> tensor<2x3x4x5x6x7x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
: tensor<24x5x42x8xf32> into tensor<40320xf32>
@@ -720,7 +722,7 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
: tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
return %1 : tensor<2x3x4x5x6x7x8xf32>
}
-// CHECK: func @reshape_expand
+// CHECK: func @compose_expand_of_collapse_7D
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
@@ -728,20 +730,37 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>)
// -----
-func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> {
+func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
+ -> tensor<?x?xi64> {
+ %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+ : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+ : tensor<?x?x?x1xi64> into tensor<?x?xi64>
+ return %1 : tensor<?x?xi64>
+}
+// CHECK-LABEL: func @compose_collapse_of_expand
+// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>)
+// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
+// CHECK-SAME: [0, 1], [2]
+// CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
+
+// -----
+
+func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
+ -> tensor<4x512xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
: tensor<2048xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<1x4x1x512xf32> into tensor<4x512xf32>
return %1 : tensor<4x512xf32>
}
-// CHECK: func @expand_reshape_1D
+// CHECK: func @compose_collapse_of_expand_1D
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
// -----
-// CHECK-LABEL: zero_rank_reshape_multi
+// CHECK-LABEL: func @zero_rank_reshape_multi
func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
%0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -752,7 +771,7 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// -----
-func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
+func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
-> tensor<?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
@@ -760,39 +779,39 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>)
: tensor<?x?x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: collapsing_tensor_reshapes
+// CHECK-LABEL: func @compose_collapse_of_collapse
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: tensor.collapse_shape
// -----
-func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
-> tensor<f32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
: tensor<1x1x1xf32> into tensor<1xf32>
%1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
return %1 : tensor<f32>
}
-// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
// CHECK: tensor.collapse_shape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
// -----
-func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
+func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
: tensor<4x512xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
: tensor<1x4x1x512xf32> into tensor<2048xf32>
return %1 : tensor<2048xf32>
}
-// CHECK: func @fold_reshape_1D
+// CHECK: func @fold_collapse_of_expand_1D
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
// -----
-func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
+func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
-> tensor<4x512x1x1xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
: tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
@@ -800,13 +819,13 @@ func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>)
: tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
return %1 : tensor<4x512x1x1xf32>
}
-// CHECK: func @fold_reshape_unit_dims
+// CHECK: func @fold_collapse_of_expand_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
// -----
-func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
+func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
-> tensor<4x512x1x512x4xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
: tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
@@ -814,69 +833,70 @@ func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}
-// CHECK: func @expand_reshape_unit_dims
+// CHECK: func @compose_collapse_of_expand_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
// -----
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> {
+func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+ -> tensor<2x1xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @compose_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
-func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
- -> tensor<?x?x?x?xf32> {
+func @compose_collapse_of_collapse_unit_dims_dynamic(
+ %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
: tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
-// CHECK: func @collapse_reshape_unit_dims_dynamic
+// CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
// CHECK: tensor.collapse_shape
// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
// -----
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
-{
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
+ -> tensor<2x1xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
-func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>)
- -> tensor<?xf32> {
+func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
+ %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
: tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
: tensor<?x1x1x1xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
-// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
+// CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
// -----
-func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
+func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
-> tensor<12x42xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
: tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
@@ -884,27 +904,28 @@ func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
: tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
return %1 : tensor<12x42xf32>
}
-// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
// -----
-func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+ -> tensor<?x?xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
: tensor<?x?x1x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+// CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
// CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
// -----
-func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
+func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
-> tensor<2x6x16xf32> {
%0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
: tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
@@ -912,20 +933,21 @@ func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>)
: tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
return %1 : tensor<2x6x16xf32>
}
-// CHECK-LABEL: func @no_fold_reshape_incompatible
+// CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
// CHECK: tensor.expand_shape
// CHECK: tensor.collapse_shape
// -----
-func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
+ -> tensor<12x1xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<3x2x2x1xf32> into tensor<12x1xf32>
return %1 : tensor<12x1xf32>
}
-// CHECK: func @no_fold_reshape_empty_expr
+// CHECK: func @no_fold_collapse_of_expand_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1], [2, 3]
@@ -1002,11 +1024,11 @@ func @fold_rank() -> (index) {
// -----
-// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-LABEL: func @pad_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: tensor.pad
// CHECK: return %[[ARG0]]
-func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
-> tensor<5x6xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
@@ -1018,11 +1040,11 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
-// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
+// CHECK-LABEL: func @pad_nofold_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK: %[[PAD:.*]] = tensor.pad
// CHECK: return %[[PAD]]
-func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
-> tensor<5x6xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
@@ -1034,7 +1056,7 @@ func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
-// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
+// CHECK-LABEL: func @pad_after_cast_different_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1046,7 +1068,7 @@ func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1059,7 +1081,7 @@ func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
// -----
-// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
+// CHECK-LABEL: func @pad_after_cast_same_shape(
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
@@ -1070,7 +1092,7 @@ func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
// CHECK: }
-func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
-> tensor<?x?x?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -1083,11 +1105,11 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
// -----
-// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-LABEL: func @pad_of_cast(
// CHECK-NOT: tensor.cast
// CHECK: tensor.pad
// CHECK: tensor<8x?xf32> to tensor<8x32xf32>
-func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
@@ -1133,7 +1155,7 @@ func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> ten
// -----
-func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
@@ -1143,17 +1165,17 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
} : tensor<?x?xf32> to tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
-// CHECK-LABEL: @tensor_pad_cast
+// CHECK-LABEL: @pad_cast
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]
// -----
-// CHECK-LABEL: func @fold_pad_tensor_source_cast(
+// CHECK-LABEL: func @fold_pad_source_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
// CHECK-NOT: tensor.cast
// CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
-func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
+func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.0 : f32
%0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
%1 = tensor.pad %0 low[0, 0] high[0, 1] {