[Mlir-commits] [mlir] 9bcda47 - [mlir][tosa] Swap the reshape at the end of the reduce op for an expand_shape in tosa-to-linalg

Rob Suderman llvmlistbot at llvm.org
Mon Oct 3 10:34:50 PDT 2022


Author: natashaknk
Date: 2022-10-03T10:32:23-07:00
New Revision: 9bcda47afcb4af2831654d7c31ad2e956202fed0

URL: https://github.com/llvm/llvm-project/commit/9bcda47afcb4af2831654d7c31ad2e956202fed0
DIFF: https://github.com/llvm/llvm-project/commit/9bcda47afcb4af2831654d7c31ad2e956202fed0.diff

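LOG: [mlir][tosa] Swap the reshape at the end of the reduce op for an expand_shape in tosa-to-linalg

Previously the trailing tosa.reshape showed up in the lowered output as a tensor.collapse_shape followed by a tensor.expand_shape (see the updated tests); the lowering now emits a single tensor.expand_shape that reinserts the reduced axis as a unit dimension.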

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133877

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 11c7d06cfe604..b54dab82c58f1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -815,8 +815,21 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   if (!didEncounterError)
     return failure();
 
-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
-                                               linalgOp.getResults());
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  uint64_t expandInputRank =
+      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+  reassociationMap.resize(expandInputRank);
+
+  for (uint64_t i = 0; i < expandInputRank; i++) {
+    int32_t dimToPush = i > axis ? i + 1 : i;
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+  }
+  int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+  reassociationMap[expandedDim].push_back(
+      rewriter.getAffineDimExpr(expandedDim + 1));
+
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      op, resultTy, linalgOp.getResults()[0], reassociationMap);
   return success();
 }
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2e4dfbfabe259..685f782d0c710 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -746,8 +746,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<?x1x4xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -768,8 +767,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor<?xf32>
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<5x?x1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }


        


More information about the Mlir-commits mailing list