[Mlir-commits] [mlir] 381677d - [tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.
Rob Suderman
llvmlistbot at llvm.org
Mon Nov 15 15:32:31 PST 2021
Author: natashaknk
Date: 2021-11-15T15:31:37-08:00
New Revision: 381677dfbfea0aeba6ee70eeb4d1441356fb916f
URL: https://github.com/llvm/llvm-project/commit/381677dfbfea0aeba6ee70eeb4d1441356fb916f
DIFF: https://github.com/llvm/llvm-project/commit/381677dfbfea0aeba6ee70eeb4d1441356fb916f.diff
LOG: [tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.
Split tosa.reshape into three individual lowerings: collapse, expand and a
combination of both. Add simple dynamic shape support.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D113936
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e90d1533b5c0e..f4470d20fca4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -946,6 +946,112 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
return success();
+/// Finds a shape that both `lhsShape` and `rhsShape` can be reshaped into by
+/// merging runs of consecutive dimensions whose sizes multiply to the same
+/// value. Each entry of `intermediateShape` is the size of one such run.
+///
+/// When `isDynamic` is set the intermediate shape is always the rank-1 shape
+/// {-1}; when either shape is rank-0 it is empty. Dimensions left over on
+/// either side after the last matched run are accepted only if they are 1.
+///
+/// Returns true and fills `intermediateShape` on success. Returns false when
+/// no such shape exists, including when a run would have to extend past the
+/// last dimension of either shape; the contents of `intermediateShape` are
+/// then unspecified.
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  intermediateShape.clear();
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {-1};
+    return true;
+  }
+
+  if (lhsShape.empty() || rhsShape.empty())
+    return true;
+
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    // Grow whichever run is currently smaller until both runs have the same
+    // size. Each step is bounds-checked: running out of dimensions on one
+    // side means the shapes cannot be aligned, and reading one past the end
+    // of an ArrayRef is undefined.
+    while (lhsSize != rhsSize) {
+      if (lhsSize < rhsSize) {
+        if (++currLhsDim == lhsShape.size())
+          return false;
+        lhsSize *= lhsShape[currLhsDim];
+      } else {
+        if (++currRhsDim == rhsShape.size())
+          return false;
+        rhsSize *= rhsShape[currRhsDim];
+      }
+    }
+    intermediateShape.push_back(lhsSize);
+    ++currRhsDim;
+    ++currLhsDim;
+  }
+
+  // If the iterators didn't reach the end, their leftover dimensions must all
+  // be 1; otherwise an intermediate shape was not found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1)
+      return false;
+  }
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1)
+      return false;
+  }
+  return true;
+}
+/// Builds the reassociation map that collapses `srcShape` into `dstShape`:
+/// `reassociationMap[i]` lists the consecutive source dimensions folded into
+/// destination dimension i. Size-1 source dimensions that follow a completed
+/// group are folded into that group unless the next destination dimension is
+/// itself 1.
+///
+/// When `isDynamic` is set every source dimension is folded into a single
+/// group; when `dstShape` is rank-0 the map is empty. Returns false if
+/// `srcShape` cannot be collapsed into `dstShape` by merging consecutive
+/// dimensions.
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    // Absorb source dimensions into the current group until its size reaches
+    // dstSize, without reading past the end of srcShape.
+    while (srcSize < dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      if (currSrcDim == srcShape.size())
+        return false;
+      srcSize *= srcShape[currSrcDim];
+    }
+    // A group that overshoots dstSize cannot be repaired by later groups.
+    // Bail out instead of leaving a wrong or empty group in the map (e.g.
+    // 2x1 -> 1x2 would otherwise yield [[], [d0, d1]]).
+    if (srcSize != dstSize)
+      return false;
+    reassociationMap[currDstDim].push_back(
+        rewriter.getAffineDimExpr(currSrcDim++));
+    // If the next dim in dstShape is not 1, treat subsequent dims in srcShape
+    // which are 1 as part of the current group.
+    if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+      while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+      }
+    }
+    currDstDim++;
+  }
+
+  // If either iterator didn't reach the end, we have leftover dimensions,
+  // which implies that we have a mismatch in shape.
+  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+}
namespace {
template <typename SrcOp>
@@ -1534,7 +1640,7 @@ class FullyConnectedConverter
-class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> { // Lowers tosa.reshape to a single linalg.tensor_collapse_shape.
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1543,103 +1649,116 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
ConversionPatternRewriter &rewriter) const final {
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape(); // NOTE(review): only the operand is inspected; a static operand with a dynamic result (new_shape containing -1) takes the static path, so -1 sizes reach the shape helpers — confirm intended.
+ if (isDynamic && resultTy.getRank() != 1) { // A dynamic operand is only collapsed all the way to rank 1.
+ return rewriter.notifyMatchFailure(
+ reshape, "Cannot collapse dynamic dims to more than one dimension");
+ }
if (operandTy == resultTy) {
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
return success();
- if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
- return failure();
+ SmallVector<ReassociationExprs, 4> reassociationMap; // Groups of consecutive operand dims, one group per result dim.
+ if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+ resultTy.getShape(),
+ reassociationMap, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape,
+ "tosa.reshape Attempting to collapse into an incompatible shape");
+ }
- // Compute the reassociation maps for the linalg operation.
- ArrayRef<int64_t> expandedShape =
- (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
- : resultTy.getShape());
- ArrayRef<int64_t> collapsedShape =
- (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
- : operandTy.getShape());
- unsigned currSrcDim = 0, currDstDim = 0;
- SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
- // First scan all dimensions in the source shapes to see whether we have a
- // perfect case where consecutive dimensions in source are collapsed. For
- // such case we can just generate one single linalg.reshape.
- bool isCollapsingSource = true;
- while (currSrcDim < expandedShape.size() &&
- currDstDim < collapsedShape.size()) {
- int64_t dstSize = collapsedShape[currDstDim];
- int64_t srcSize = expandedShape[currSrcDim];
- while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- srcSize *= expandedShape[currSrcDim];
- }
- if (srcSize == dstSize) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- // If the next dim in collapsedShape is not 1, treat subsequent dims in
- // expandedShape which are 1 to be collapsed.
- if (currDstDim == collapsedShape.size() - 1 ||
- collapsedShape[currDstDim + 1] != 1) {
- while (currSrcDim < expandedShape.size() &&
- expandedShape[currSrcDim] == 1) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- }
- }
- } else {
- isCollapsingSource = false;
- break;
- }
- currDstDim++;
+ SmallVector<int64_t> intermediateShape; // Only used to check that the two shapes are compatible; not read afterwards.
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot collapse into given shape");
- // Check if any remaining dimensions exist. If either is rank-0 we only
- // require the directly lowering.
- if (currSrcDim != expandedShape.size() ||
- currDstDim != collapsedShape.size())
- isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
- // Otherwise, we need to first reduce all source dimensions into one and
- // then expand to the destination dimensions.
- if (!isCollapsingSource) {
- auto getIdentityExprs = [&rewriter](int n) {
- SmallVector<AffineExpr, 4> exprs;
- for (int i = 0; i < n; ++i)
- exprs.push_back(rewriter.getAffineDimExpr(i));
- return exprs;
- };
- Location loc = reshape.getLoc();
- int64_t totalElems =
- std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
- std::multiplies<int64_t>());
- auto elemTy = operandTy.getElementType();
- SmallVector<ReassociationExprs, 4> collapsingMap = {
- // Use operandTy here because we need to collapse all operands
- // dimensions.
- getIdentityExprs(operandTy.getShape().size())};
- SmallVector<ReassociationExprs, 4> expandingMap = {
- // Use resultTy here because we need to expand to all result
- // dimensions.
- getIdentityExprs(resultTy.getShape().size())};
- auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
- Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
- loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
- rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
- reshape, resultTy, collapsedOp, expandingMap);
+ rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>( // Replace the reshape with one collapse using the map built above.
+ reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+ return success();
+ }
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> { // Lowers tosa.reshape to a single linalg.tensor_expand_shape.
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape(); // NOTE(review): only the operand is inspected; a static operand with a dynamic result takes the static path — confirm intended.
+ if (operandTy == resultTy) {
+ rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
return success();
- if (resultTy.getRank() <
- adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
- rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
- else
- rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+ if (isDynamic && operandTy.getRank() != 1) { // A dynamic operand is only expanded when it is rank 1.
+ return rewriter.notifyMatchFailure(
+ reshape, "Cannot expand dynamic dims from more than one dimension");
+ }
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(), // Expansion map is built as the inverse collapse: result dims grouped per operand dim.
+ operandTy.getShape(),
+ reassociationMap, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape,
+ "tosa.reshape Attempting to expand into an incompatible shape");
+ }
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic) ||
+ intermediateShape != operandTy.getShape()) { // Require the common shape to be exactly the operand shape, i.e. a pure expansion.
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot expand into given shape");
+ }
+ rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+ reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+ return success();
+ }
+class ReshapeConverterCollapseExpand // Rewrites tosa.reshape as two tosa.reshape ops that pass through a common intermediate shape.
+ : public OpConversionPattern<tosa::ReshapeOp> {
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape(); // NOTE(review): only the operand is inspected; a static operand with a dynamic result gets a static intermediate-shape search — confirm intended.
+ if (operandTy == resultTy) {
+ rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
+ return success();
+ }
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+ intermediateShape, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot identify an intermediate shape between "
+ "the given two shapes");
+ }
+ Value collapse = rewriter.create<tosa::ReshapeOp>( // NOTE(review): neither new tosa.reshape is given a new_shape attribute — confirm the builder/verifier accept that; both are presumably lowered afterwards by the collapse/expand patterns.
+ reshape.getLoc(),
+ RankedTensorType::get(intermediateShape,
+ reshape.getType().getElementType()),
+ adaptor.input1());
+ Value expand =
+ rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+ rewriter.replaceOp(reshape, expand);
return success();
@@ -3072,7 +3191,9 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
- ReshapeConverter,
+ ReshapeConverterCollapse,
+ ReshapeConverterExpand,
+ ReshapeConverterCollapseExpand,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index d072808b1b476..2e25ad975a09e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -541,6 +541,16 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
// -----
+// CHECK-LABEL: @test_reshape_downrank_dyn
+func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> { // Dynamic rank-2 operand collapsed to a dynamic rank-1 result.
+ // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor<?xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<?xf32>
+// -----
// CHECK-LABEL: @test_reshape_uprank
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
@@ -551,6 +561,16 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// -----
+// CHECK-LABEL: @test_reshape_uprank_dyn
+func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> { // Dynamic rank-1 operand expanded to a rank-2 result.
+ // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?xf32>) -> tensor<2x?xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<2x?xf32>
+// -----
// CHECK-LABEL: @test_reshape_samerank
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
@@ -563,6 +583,18 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
// -----
+// CHECK-LABEL: @test_reshape_samerank_dyn
+func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> { // Same-rank dynamic reshape, expected as a collapse to rank 1 followed by an expand.
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+ // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+ // CHECK-NEXT: return %[[RESHAPE2]]
+ return %0 : tensor<2x?xf32>
+// -----
// CHECK-LABEL: @test_reshape_downrank_6D
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
@@ -572,6 +604,16 @@ func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77
// -----
+// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  // The dynamic operand forces a full collapse to rank 1 followed by an expand.
+  // CHECK: %[[COLLAPSE:.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
+  // CHECK: %[[EXPAND:.+]] = linalg.tensor_expand_shape %[[COLLAPSE]] {{\[}}[0, 1, 2]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+  // CHECK: return %[[EXPAND]]
+  return %0 : tensor<?x5x77xf32>
+}
+// -----
// CHECK-LABEL: @test_identity
func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
%0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
More information about the Mlir-commits
mailing list