[Mlir-commits] [mlir] 0720576 - [mlir][tosa] Remove zero-fill of tosa.concat outputs when lowering to linalg.
Rob Suderman
llvmlistbot at llvm.org
Mon Nov 14 12:01:54 PST 2022
Author: Rob Suderman
Date: 2022-11-14T11:39:24-08:00
New Revision: 07205761417d718dd0bddde9a1165af438cb1eaa
URL: https://github.com/llvm/llvm-project/commit/07205761417d718dd0bddde9a1165af438cb1eaa
DIFF: https://github.com/llvm/llvm-project/commit/07205761417d718dd0bddde9a1165af438cb1eaa.diff
LOG: [mlir][tosa] Remove zero-fill of tosa.concat outputs when lowering to linalg.
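Since every output element is known to be overwritten by construction (the `tensor.insert_slice` ops for the operands tile the entire result tensor along the concat axis), the fill is not required. This change makes the TOSA lowering consistent with the MHLO and Torch lowerings of concat, which do not perform the fill.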
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D137967
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f56162ba0acd5..7e4c38b449ef9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1781,19 +1781,13 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultType.getShape(), resultType.getElementType(), dynDims);
- Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(resultType.getElementType()));
- Value result = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zeroVal},
- ValueRange{emptyTensor})
- .result();
-
auto toOpFoldResult = [](Value v) -> OpFoldResult {
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
if (!op)
return v;
return op.getValue();
};
+ Value result = emptyTensor;
for (auto arg : adaptor.getOperands()) {
sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
result = rewriter.createOrFold<tensor::InsertSliceOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2aeb7c8607719..52674d4b58ad0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -929,9 +929,7 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// CHECK: [[IDX0:%.+]] = arith.constant 0 : index
// CHECK: [[IDX1:%.+]] = arith.constant 1 : index
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
- // CHECK: [[CST:%.+]] = arith.constant 0.0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]]
- // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[FILL]][0, 0] [5, 1] [1, 1]
+ // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
@@ -941,9 +939,7 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// CHECK: [[IDX0:%.+]] = arith.constant 0 : index
// CHECK: [[IDX1:%.+]] = arith.constant 1 : index
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
- // CHECK: [[CST:%.+]] = arith.constant 0.0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]]
- // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[FILL]][0, 0] [5, 1] [1, 1]
+ // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
%1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
return
@@ -964,9 +960,7 @@ func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -
// CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
// CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
- // CHECK: %[[CST:.+]] = arith.constant 0.0
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]]
- // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
return
@@ -987,10 +981,8 @@ func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> ()
// CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
// CHECK: %[[IDX1:.+]] = arith.constant 1 : index
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
- // CHECK: %[[CST:.+]] = arith.constant 0.0
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]]
// CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
- // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
// CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
// CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
More information about the Mlir-commits
mailing list