[Mlir-commits] [mlir] b9b10c0 - [tosa][mlir] Lowering for dynamic shapes in the reduce_x ops in tosa-to-linalg
Rob Suderman
llvmlistbot at llvm.org
Wed Jan 19 11:18:53 PST 2022
Author: natashaknk
Date: 2022-01-19T11:15:14-08:00
New Revision: b9b10c0e615fdd7cb0687434d37bbe6bfb804639
URL: https://github.com/llvm/llvm-project/commit/b9b10c0e615fdd7cb0687434d37bbe6bfb804639
DIFF: https://github.com/llvm/llvm-project/commit/b9b10c0e615fdd7cb0687434d37bbe6bfb804639.diff
LOG: [tosa][mlir] Lowering for dynamic shapes in the reduce_x ops in tosa-to-linalg
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D117691
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1fab060a6b62f..5e94633a7b60a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -830,19 +830,22 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
Value input = op->getOperand(0);
llvm::SmallVector<int64_t> reduceShape;
+ SmallVector<Value> dynDims;
for (unsigned i = 0; i < inputTy.getRank(); i++) {
- if (axis != i)
+ if (axis != i) {
reduceShape.push_back(inputTy.getDimSize(i));
+ if (inputTy.isDynamicDim(i))
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
}
Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
// First fill the output buffer with the init value.
- auto initTensor =
- rewriter
- .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
- resultTy.getElementType())
- .result();
+ auto initTensor = rewriter
+ .create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
+ resultTy.getElementType())
+ .result();
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
if (!fillValueAttr)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3706c4131dcbb..27487a4b8e8bf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -736,6 +736,72 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// -----
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: @reduce_float_dyn
+func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {  // reduce_sum over static axis 1; the leading dynamic dim is preserved via tensor.dim
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 4]
+ // CHECK: %[[CST0:.+]] = arith.constant 0.0
+ // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST0]], %[[INIT]])
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x5x4xf32>) outs(%[[FILL]] : tensor<?x4xf32>)
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+ // CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
+ // CHECK: linalg.yield %[[RES]] : f32
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
+ // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<?x1x4xf32>
+ %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: @reduce_float_dyn_nonzero_batch
+func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {  // reduce_prod over static axis 2; the dynamic dim sits at index 1, not 0
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C1]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, %[[DYN]]]
+ // CHECK: %[[CST1:.+]] = arith.constant 1.0
+ // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST1]], %[[INIT]])
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>)
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+ // CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
+ // CHECK: linalg.yield %[[RES]] : f32
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor<?xf32>
+ // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor<?xf32> into tensor<5x?x1xf32>
+ %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @reduce_float_dyn_multiple
+func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {  // reduce_max over dynamic axis 1; only the kept dim 0 needs a tensor.dim
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]]
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
+ // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38
+ // CHECK: %[[FILL:.+]] = linalg.fill(%[[CMIN]], %[[INIT]])
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
+ // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+ // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
+ // CHECK: %[[RES:.+]] = select %[[CMP]], %arg1, %arg2 : f32
+ // CHECK: linalg.yield %[[RES]] : f32
+ // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+ %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
More information about the Mlir-commits mailing list