[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 19 07:41:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir
Author: Rafael Ubal (rafaelubalmw)
<details>
<summary>Changes</summary>
This pull request adds missing features to the lowering pattern for the `tosa.reshape` op and fixes a bug in its folder.
- Example of a valid use of `tosa.reshape` previously not supported:
```
func.func @main(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
%0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
return %0 : tensor<1x3x2x1xf32>
}
```
- The new lowering is based on the use of `tensor.reshape` instead of a combination of `tensor.collapse_shape` + `tensor.expand_shape`.
- When no -1 placeholder is present in the `new_shape` attribute, the target shape is encoded with an `arith.constant` op and the reshape occurs with a `tensor.reshape` op.
- When a -1 placeholder is used in `new_shape` and the corresponding dimension in the result type is dynamic, the missing dimension size is inferred by calculating the input tensor size (`tensor.collapse_shape` + `tensor.dim`) and dividing it by the product of all other target dimension sizes (`arith.divui`).
- When a -1 placeholder is used in `new_shape` and the corresponding dimension in the result type is static, the missing dimension size is grabbed from the result type.
- Fixed a bug in the folder `tosa::ReshapeOp::fold()`, which canonicalization also invokes. Equal input and result types are not a sufficient condition for folding: if the types have more than one dynamic dimension, the reshape can still redistribute elements among them (e.g. a `2x3` input reshaped to `3x2`, both typed `tensor<?x?xf32>`). Unit tests now cover one (existing test) and two (new test) dynamic dimensions.
---
Patch is 25.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85798.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+78-195)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+4-1)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+144-23)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+8)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 505d85f211111c..62ed41ebda4f50 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -19,217 +19,98 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
+
using namespace mlir;
using namespace tosa;
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
- ArrayRef<int64_t> rhsShape,
- SmallVector<int64_t> &intermediateShape,
- bool isDynamic) {
- if (isDynamic) {
- // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
- intermediateShape = {ShapedType::kDynamic};
- return true;
- }
-
- if (lhsShape.empty() || rhsShape.empty()) {
- intermediateShape = {};
- return true;
- }
-
- unsigned currLhsDim = 0, currRhsDim = 0;
- while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
- int64_t rhsSize = rhsShape[currRhsDim];
- int64_t lhsSize = lhsShape[currLhsDim];
- while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
- currRhsDim < rhsShape.size()) {
- if (lhsSize < rhsSize) {
- currLhsDim++;
- if (currLhsDim < lhsShape.size()) {
- lhsSize *= lhsShape[currLhsDim];
- }
- } else {
- currRhsDim++;
- if (currRhsDim < rhsShape.size()) {
- rhsSize *= rhsShape[currRhsDim];
- }
- }
- }
- if (lhsSize == rhsSize) {
- intermediateShape.push_back(lhsSize);
- }
- currRhsDim++;
- currLhsDim++;
- }
-
- // If the iterators didn't reach the end and their leftover dimensions are not
- // equal to 1 an intermediate shape was not found.
- while (currLhsDim < lhsShape.size()) {
- if (lhsShape[currLhsDim++] != 1) {
- return false;
- }
- }
-
- while (currRhsDim < rhsShape.size()) {
- if (rhsShape[currRhsDim++] != 1) {
- return false;
- }
- }
-
- return true;
+// Materialize `index` as an `arith.constant` of index type at `loc`.
+static Value getIndexConstant(OpBuilder &builder, Location loc,
+ int64_t index) {
+ Value constant = builder.create<arith::ConstantIndexOp>(loc, index);
+ return constant;
}
-static bool createReassociationMapsForCollapse(
- PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> dstShape,
- SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
- // If the shape is dynamic, create a map for collapsing into one dimension.
- if (isDynamic) {
- SmallVector<AffineExpr, 2> exprs;
- for (int i = 0, s = srcShape.size(); i < s; ++i)
- exprs.push_back(rewriter.getAffineDimExpr(i));
- reassociationMap = {exprs};
- return true;
- }
-
- if (dstShape.empty()) {
- reassociationMap = {};
- return true;
- }
-
- reassociationMap.resize(dstShape.size());
- unsigned currSrcDim = 0, currDstDim = 0;
- while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
- int64_t dstSize = dstShape[currDstDim];
- int64_t srcSize = srcShape[currSrcDim];
- while (srcSize < dstSize && currSrcDim < srcShape.size()) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- srcSize *= srcShape[currSrcDim];
- }
- if (srcSize == dstSize) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- // If the next dim in collapsedShape is not 1, treat subsequent dims in
- // expandedShape which are 1 to be collapsed.
- if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
- while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- }
- }
- }
- currDstDim++;
+// Return the total number of elements of the given input tensor as an
+// index-typed value.
+static Value getTensorSize(OpBuilder &builder, Location loc,
+ TypedValue<TensorType> input) {
+ // If the input tensor is statically shaped, return its size as a constant.
+ // The product is accumulated in int64_t: std::accumulate uses the type of
+ // its seed as the accumulator, so an int seed of 1 would truncate every
+ // partial product to 32 bits for large tensors.
+ if (input.getType().hasStaticShape()) {
+ ArrayRef<int64_t> shape = input.getType().getShape();
+ int64_t size = std::accumulate(shape.begin(), shape.end(), int64_t{1},
+ std::multiplies<int64_t>());
+ return getIndexConstant(builder, loc, size);
}
- // If both iterators didn't reach the end, we have leftover dimentions which
- // implies that we have a mismatch in shape.
- return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+ // When the input tensor has at least one dynamic dimension, collapse it into
+ // a 1D tensor and query the size of its only dimension.
+ auto rank = input.getType().getRank();
+ auto elementType = input.getType().getElementType();
+ auto collapsedType =
+ RankedTensorType::get({ShapedType::kDynamic}, elementType);
+ auto reassociationIndices = SmallVector<ReassociationIndices>{
+ llvm::to_vector(llvm::seq<int64_t>(rank))
+ };
+ auto collapsed = builder.create<tensor::CollapseShapeOp>(
+ loc, collapsedType, input, reassociationIndices);
+ return builder.create<tensor::DimOp>(loc, collapsed, 0);
}
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
- ShapedType resultTy, Value operand) {
- ShapedType operandTy = cast<ShapedType>(operand.getType());
- if (resultTy == operandTy)
- return operand;
-
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && resultTy.getRank() != 1) {
- (void)rewriter.notifyMatchFailure(
- loc, "Cannot collapse dynamic dims to more than one dimension");
- return {};
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
- resultTy.getShape(),
- reassociationMap, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Attempting to collapse into an incompatible shape");
- return {};
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Cannot collapse into given shape");
- return {};
- }
- return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
- reassociationMap);
+// Compute the dimension size of the result tensor corresponding to the
+// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
+// op. Argument 'index' indicates the position of the -1 placeholder.
+static Value getReshapePlaceholderDimSize(OpBuilder &builder,
+ tosa::ReshapeOp reshape,
+ int64_t index) {
+ auto loc = reshape.getLoc();
+ auto input = reshape.getInput1();
+ auto newShape = reshape.getNewShape();
+ auto resultType = reshape.getResult().getType();
+
+ // If the corresponding dimension in the result type is static, take the
+ // dimension size from there.
+ assert(newShape[index] == -1);
+ if (!resultType.isDynamicDim(index))
+ return getIndexConstant(builder, loc, resultType.getDimSize(index));
+
+ // Calculate the product of all dimensions in the new shape. We expect to have
+ // exactly one size set to -1, so we can discard this component by just
+ // negating the final product. The seed is int64_t on purpose: the
+ // accumulator of std::accumulate has the seed's type, so an int seed would
+ // truncate the product to 32 bits even with std::multiplies<int64_t>.
+ int64_t newSizeLiteral =
+ -std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
+ std::multiplies<int64_t>());
+ assert(newSizeLiteral >= 0 && "expected exactly one -1 placeholder");
+ Value newSize = getIndexConstant(builder, loc, newSizeLiteral);
+
+ // Avoid a division by zero. If any of the given dimension sizes was set to
+ // zero, set the placeholder size to zero, too.
+ if (newSizeLiteral == 0)
+ return newSize;
+
+ // The size of the placeholder dimension is the size of the input tensor
+ // divided by all non-placeholder dimension sizes.
+ auto inputSize = getTensorSize(builder, loc, input);
+ return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
}
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
- ShapedType resultTy, Value operand) {
- ShapedType operandTy = cast<ShapedType>(operand.getType());
- if (resultTy == operandTy)
- return operand;
-
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && operandTy.getRank() != 1) {
- (void)rewriter.notifyMatchFailure(
- loc, "Cannot expand dynamic dims from more than one dimension");
- return {};
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
- operandTy.getShape(),
- reassociationMap, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Attempting to expand into an incompatible shape");
- return {};
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic) ||
- intermediateShape != operandTy.getShape()) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Cannot expand into given shape");
- return {};
- }
- return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
- reassociationMap);
-}
+namespace {
-class ReshapeConverterCollapseExpand
- : public OpConversionPattern<tosa::ReshapeOp> {
+// Lower 'tosa.reshape' to 'tensor.reshape'. The target shape is materialized
+// as a 1-D tensor of index values, one per entry of 'new_shape'; -1
+// placeholders are resolved by getReshapePlaceholderDimSize().
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
- ShapedType resultTy = cast<ShapedType>(reshape.getType());
- bool isDynamic = !operandTy.hasStaticShape();
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
- intermediateShape, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot identify an intermediate shape between "
- "the given two shapes");
+ Location loc = reshape.getLoc();
+ // Reshape the operand as remapped by the conversion driver (the adaptor)
+ // rather than the original op's operand, as the replaced lowering did.
+ Value input = adaptor.getInput1();
+ ArrayRef<int64_t> newShape = reshape.getNewShape();
+
+ // Create list of values for new shape
+ SmallVector<Value> newShapeVector(newShape.size());
+ for (auto [index, size] : llvm::enumerate(newShape)) {
+ newShapeVector[index] = size == -1 ?
+ getReshapePlaceholderDimSize(rewriter, reshape, index) :
+ getIndexConstant(rewriter, loc, size);
}
- auto intermediateTy = RankedTensorType::get(
- intermediateShape, reshape.getType().getElementType());
- Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
- adaptor.getInput1());
- if (!collapse)
- return failure();
-
- Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
- if (!expand)
- return failure();
-
- rewriter.replaceOp(reshape, expand);
+ // Reshape tensor. The shape tensor type (tensor<N x index>) is spelled out
+ // instead of being inferred by tensor.from_elements: the inferring builder
+ // needs at least one element, and reshaping to a rank-0 result has an
+ // empty 'new_shape'.
+ auto newShapeType = RankedTensorType::get(
+ {static_cast<int64_t>(newShapeVector.size())}, rewriter.getIndexType());
+ auto newShapeTensor = rewriter.createOrFold<tensor::FromElementsOp>(
+ loc, newShapeType, newShapeVector);
+ rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
+ reshape, reshape.getResult().getType(), input, newShapeTensor);
return success();
}
};
@@ -416,8 +297,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<SliceConverter, PadConverter, ConcatConverter>(
- patterns->getContext());
-
- patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+ // Register every TOSA-to-Tensor conversion pattern (alphabetical order).
+ MLIRContext *context = patterns->getContext();
+ patterns->add<ConcatConverter, PadConverter, ReshapeConverter,
+ SliceConverter>(context);
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4c50aaecfe9488..d23c9fe824c94a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!inputTy || !outputTy)
return {};
- if (inputTy == outputTy)
+ // Fold when the input and output types are the same and carry at most one
+ // dynamic dimension: a reshape preserves the element count and every other
+ // extent is fixed by the type, so the single dynamic extent cannot change.
+ // With two or more dynamic dimensions the extents can still be redistributed
+ // (e.g. 2x3 -> 3x2, both typed tensor<?x?xf32>), so the reshape is productive.
+ if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
return getInput1();
// reshape(reshape(x)) -> reshape(x)
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..e1fd7838293b6a 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,11 +1,15 @@
// RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
+// -----
+
// CHECK-LABEL: @test_reshape_downrank
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ // Fully static reshape: the target shape operand is a single constant.
+ // CHECK: %[[SHAPE:.+]] = arith.constant dense<6> : tensor<1xindex>
+ // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x3xf32>, tensor<1xindex>) -> tensor<6xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
- // CHECK: return [[RESHAPE]]
+
+ // CHECK: return %[[RESHAPE]] : tensor<6xf32>
return %0 : tensor<6xf32>
}
@@ -14,9 +18,16 @@ func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
// CHECK-LABEL: @test_reshape_downrank_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ // The lone -1 placeholder resolves to the element count of the dynamic
+ // input, read back with tensor.dim after collapsing the input to 1-D.
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+ // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
+ // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+
+ // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+ // CHECK-DAG: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x?xf32>, tensor<1xindex>) -> tensor<?xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
- // CHECK: return [[RESHAPE]]
+
+ // CHECK: return %[[RESHAPED]] : tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -25,9 +36,10 @@ func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
// CHECK-LABEL: @test_reshape_uprank
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+ // Fully static reshape to a higher rank; the target shape is one constant.
+ // CHECK: %[[SHAPE:.+]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+ // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<6xf32>, tensor<2xindex>) -> tensor<2x3xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
- // CHECK: return [[RESHAPE]]
+ // CHECK: return %[[RESHAPE]] : tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
@@ -36,57 +48,166 @@ func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// CHECK-LABEL: @test_reshape_uprank_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+ // The -1 placeholder is the input element count divided (arith.divui) by
+ // the product of the remaining target dimensions, here the constant 2.
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C2_0:.+]] = arith.constant 2 : index
+
+ // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0]] : tensor<?xf32> into tensor<?xf32>
+ // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+ // CHECK-DAG: %[[PLACEHOLDER:.+]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+ // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+ // CHECK: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
- // CHECK: return [[RESHAPE]]
+
+ // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
// -----
// CHECK-LABEL: @test_reshape_samerank
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+ // Same-rank static reshape: one tensor.reshape with a constant target shape.
+ // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2xf32>, tensor<2xindex>) -> tensor<2x3xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
+
+ // CHECK: return %[[RESHAPED]] : tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
// -----
// CHECK-LABEL: @test_reshape_samerank_dyn
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+ // Same-rank reshape of a dynamic input: the -1 placeholder is the input
+ // element count divided by the static target dimension (2).
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+
+ // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
+ // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+ // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+ // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+ // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x2xf32>, tensor<2xindex>) -> tensor<2x?xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
+
+ // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_downrank_6D
+// CHECK-LABEL: @test_reshape_downrank_6d
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
+func.func @test_reshape_downrank_6d(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+ // Rank-6 to rank-3 static reshape; the target shape is a single constant.
+ // CHECK: %[[SHAPE:.*]] = arith.constant dense<[6, 5, 77]> : tensor<3xindex>
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x3x5x7x11xf32>, tensor<3xindex>) -> tensor<6x5x77xf32>
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+
+ // CHECK: return %[[RESHAPED]] : tensor<6x5x77xf32>
return %0 : tensor<6x5x77xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+// CHECK-LABEL: @test_reshape_downrank_6d_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
- // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+func.func @test_reshape_downrank_6d_dy...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/85798
More information about the Mlir-commits mailing list