[Mlir-commits] [mlir] 2d61628 - [mlir][tosa] Swap reshape at end of reduce op with expand_shape
Robert Suderman
llvmlistbot at llvm.org
Tue Mar 14 11:54:17 PDT 2023
Author: Ramiro Leal-Cavazos
Date: 2023-03-14T18:51:39Z
New Revision: 2d61628c1f49963921b9ac1995218191dc5e3091
URL: https://github.com/llvm/llvm-project/commit/2d61628c1f49963921b9ac1995218191dc5e3091
DIFF: https://github.com/llvm/llvm-project/commit/2d61628c1f49963921b9ac1995218191dc5e3091.diff
LOG: [mlir][tosa] Swap reshape at end of reduce op with expand_shape
This commit swaps the `tosa.reshape` op used at the end of the
lowering for reduce ops back to `tensor.expand_shape`. This is needed
to properly support dynamically sized tensors: lowering directly to
`tensor.expand_shape` lets us use knowledge of the reduction to
control which dimension gets expanded at the end. This would not be
possible with `tosa.reshape`, since that op has no way of knowing that
we are only unsqueezing a single dimension.
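As an illustrative sketch (not part of the change; the function name
is made up, and the shapes mirror the `reduce_float_dyn` test updated
below), the reassociation indices pin the inserted unit dimension
explicitly:

  // Reducing tensor<?x5x4xf32> over axis 1 yields a collapsed tensor<?x4xf32>.
  // The reassociation [[0], [1, 2]] states that the unit dimension goes at
  // position 1, which a shape-only reshape could not infer once the leading
  // dimension is dynamic.
  func.func @expand_after_reduce(%reduced: tensor<?x4xf32>) -> tensor<?x1x4xf32> {
    %expanded = tensor.expand_shape %reduced [[0], [1, 2]]
        : tensor<?x4xf32> into tensor<?x1x4xf32>
    return %expanded : tensor<?x1x4xf32>
  }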
Note: this change had previously been performed in
https://reviews.llvm.org/D133877.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D145986
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f6ca01949632..271a09539e46 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -807,9 +807,28 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
return rewriter.notifyMatchFailure(
op, "unable to create linalg.generic body for reduce op");
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultTy, linalgOp.getResults()[0],
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ uint64_t expandInputRank =
+ linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+ reassociationMap.resize(expandInputRank);
+
+ for (uint64_t i = 0; i < expandInputRank; i++) {
+ int32_t dimToPush = i > axis ? i + 1 : i;
+ reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+ }
+
+ if (expandInputRank != 0) {
+ int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+ reassociationMap[expandedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim + 1));
+ }
+
+ // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
+ // since here we know which dimension to expand, and `tosa::ReshapeOp` would
+ // not have access to such information. This matters when handling dynamically
+ // sized tensors.
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+ op, resultTy, linalgOp.getResults()[0], reassociationMap);
return success();
}
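For the rank-0 edge case (reducing a rank-1 input), the
`expandInputRank != 0` guard leaves the reassociation map empty. A
hedged sketch of the resulting IR, matching the
`reduce_float_dyn_rank_1` test update below (the function name is
illustrative only):

  func.func @expand_rank_zero(%reduced: tensor<f32>) -> tensor<1xf32> {
    // Empty reassociation: the scalar reduction result is expanded into a
    // single unit dimension.
    %expanded = tensor.expand_shape %reduced [] : tensor<f32> into tensor<1xf32>
    return %expanded : tensor<1xf32>
  }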
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 427fe6b2f16b..133999eff1ec 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -626,7 +626,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -636,7 +636,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
// CHECK: arith.constant 1.0
@@ -676,7 +676,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1, 4>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
return
}
@@ -696,7 +696,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
return
}
@@ -718,7 +718,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -9223372036854775808, 1>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
%0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
return
}
@@ -740,7 +740,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[MAX]] : f32
- // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1>}
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
}
@@ -761,7 +761,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: [[INIT:%.+]] = tensor.empty()
@@ -771,7 +771,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
// CHECK: arith.constant 1
@@ -811,7 +811,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1)
// CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
// CHECK: linalg.yield [[RES]] : i1
- // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
+ // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
%0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
// CHECK: arith.constant false