[Mlir-commits] [mlir] 9bcda47 - [mlir][tosa] Swap the reshape at the end of the reduce op for an expand_shape in tosa-to-linalg
Rob Suderman
llvmlistbot at llvm.org
Mon Oct 3 10:34:50 PDT 2022
Author: natashaknk
Date: 2022-10-03T10:32:23-07:00
New Revision: 9bcda47afcb4af2831654d7c31ad2e956202fed0
URL: https://github.com/llvm/llvm-project/commit/9bcda47afcb4af2831654d7c31ad2e956202fed0
DIFF: https://github.com/llvm/llvm-project/commit/9bcda47afcb4af2831654d7c31ad2e956202fed0.diff
LOG: [mlir][tosa] Swap the reshape at the end of the reduce op for an expand_shape in tosa-to-linalg
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D133877
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 11c7d06cfe604..b54dab82c58f1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -815,8 +815,21 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
if (!didEncounterError)
return failure();
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
- linalgOp.getResults());
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ uint64_t expandInputRank =
+ linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+ reassociationMap.resize(expandInputRank);
+
+ for (uint64_t i = 0; i < expandInputRank; i++) {
+ int32_t dimToPush = i > axis ? i + 1 : i;
+ reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+ }
+ int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+ reassociationMap[expandedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim + 1));
+
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+ op, resultTy, linalgOp.getResults()[0], reassociationMap);
return success();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2e4dfbfabe259..685f782d0c710 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -746,8 +746,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
// CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
- // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<?x1x4xf32>
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
return
}
@@ -768,8 +767,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
// CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor<?xf32>
- // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<5x?x1xf32>
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
%0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
return
}
More information about the Mlir-commits mailing list