[Mlir-commits] [mlir] c19a412 - [MLIR][TOSA] Tosa elementwise broadcasting
Rob Suderman
llvmlistbot at llvm.org
Wed Feb 10 15:29:03 PST 2021
Author: Rob Suderman
Date: 2021-02-10T15:28:18-08:00
New Revision: c19a4128095da2191da7c04a862fb298bcf1298c
URL: https://github.com/llvm/llvm-project/commit/c19a4128095da2191da7c04a862fb298bcf1298c
DIFF: https://github.com/llvm/llvm-project/commit/c19a4128095da2191da7c04a862fb298bcf1298c.diff
LOG: [MLIR][TOSA] Tosa elementwise broadcasting
Added support for broadcasting size-1 dimensions for TOSA elementwise
operations.
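
The lowering now emits one indexing map per operand: any dimension of
size 1 is indexed by the affine constant 0, so every iteration of the
linalg.generic re-reads that operand's single element, while all other
dimensions use the identity dimension expression. A minimal sketch of
that construction, assuming the usual MLIR builder APIs (the helper
name buildBroadcastMap is illustrative, not part of this patch):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Build the indexing map for one operand: size-1 dims collapse to the
// affine constant 0 (broadcast); all other dims map to themselves.
static AffineMap buildBroadcastMap(ArrayRef<int64_t> shape,
                                   Builder &builder) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(shape.size());
  for (unsigned i = 0, e = shape.size(); i < e; ++i)
    exprs.push_back(shape[i] == 1 ? builder.getAffineConstantExpr(0)
                                  : builder.getAffineDimExpr(i));
  return AffineMap::get(/*dimCount=*/shape.size(), /*symbolCount=*/0,
                        exprs, builder.getContext());
}

For a tensor<1x3xf32> operand in a rank-2 iteration space this yields
affine_map<(d0, d1) -> (0, d1)>, which is exactly what the new
@test_multibroadcast case below checks; result maps remain the identity.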
Differential Revision: https://reviews.llvm.org/D96190
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ecd15d62b88..fcc5a5230a45 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -152,23 +152,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
return rewriter.notifyMatchFailure(operation,
"All results must be a shaped type");
- // For now require no broadcasting. Consider making it support broadcasting
- // operations.
- Type uniqueInTy = operation->getOperand(0).getType();
- bool allInputTypesEqual =
- llvm::all_of(operation->getOperandTypes(),
- [&](Type operandTy) { return operandTy == uniqueInTy; });
- if (!allInputTypesEqual)
- return rewriter.notifyMatchFailure(operation,
- "All operands must have the same type");
- bool resultAndInputShapeEqual =
- llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
- return resultTy.cast<ShapedType>().getShape() == t0.getShape();
- });
-
- if (!resultAndInputShapeEqual)
- return rewriter.notifyMatchFailure(
- operation, "All results must have the same shape as the input");
+ assert(operation->getNumResults() == 1 &&
+ "All TOSA elementwise ops should only return a single result.");
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type> bodyArgTypes;
@@ -194,12 +179,30 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
- // Supports only non-broadcasted operation. Shoudl consider update indexing
- // map to be multidimensional.
unsigned nloops = t0.getRank();
- AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
- SmallVector<AffineMap, 2> indexingMaps(
- operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+ SmallVector<AffineMap, 2> indexingMaps;
+ indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
+
+ // Input indexing maps may be broadcasted.
+ for (Type types : operation->getOperandTypes()) {
+ auto shape = types.cast<ShapedType>().getShape();
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(nloops);
+ for (unsigned i = 0; i < nloops; ++i) {
+ // If the dimension is one we can broadcast the input with a constant
+ // affine expression.
+ if (shape[i] == 1)
+ dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+ else
+ dimExprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
+ /*symbolCount=*/0, dimExprs,
+ rewriter.getContext()));
+ }
+
+ indexingMaps.append(operation->getNumResults(),
+ rewriter.getMultiDimIdentityMap(nloops));
bool didEncounterError = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e416246a19a4..8963544838e1 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
-// CHECK: #map = affine_map<() -> ()>
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: @test_abs
func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
@@ -19,54 +19,73 @@ func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
// -----
-// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<1xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: } -> tensor<2xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return [[GENERIC]]
- return %0 : tensor<1xf32>
+ return %0 : tensor<2xf32>
}
// -----
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<1x2xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+ // CHECK: } -> tensor<2x3xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: return [[GENERIC]]
- return %0 : tensor<1x2xf32>
+ return %0 : tensor<2x3xf32>
}
// -----
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
- // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @test_broadcast
+func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+ // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+ // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2xf32>
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
- // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
- return %0 : tensor<1xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_multibroadcast
+func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+ // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+ // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2x3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
}
// -----