[Mlir-commits] [mlir] 764ad3b - [mlir][tosa] Tosa elementwise broadcasting had some minor bugs

Rob Suderman llvmlistbot at llvm.org
Tue May 11 14:04:16 PDT 2021


Author: Rob Suderman
Date: 2021-05-11T13:58:06-07:00
New Revision: 764ad3b3fafbf57ca916715625fffb7df5dbeb92

URL: https://github.com/llvm/llvm-project/commit/764ad3b3fafbf57ca916715625fffb7df5dbeb92
DIFF: https://github.com/llvm/llvm-project/commit/764ad3b3fafbf57ca916715625fffb7df5dbeb92.diff

LOG: [mlir][tosa] Tosa elementwise broadcasting had some minor bugs

Updated tests to cover broadcasting of both the left and right operands.
Also adds a bypass for operands whose shape already matches the result
shape (no broadcasting needed).

Differential Revision: https://reviews.llvm.org/D102276

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 4bf2dc7b94eb2..3718a56b38f02 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -522,8 +522,12 @@ static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation,
                                  PatternRewriter &rewriter) {
   auto loc = operation->getLoc();
+
+  assert(operation->getNumResults() == 1 &&
+         "All TOSA elementwise ops should only return a single result.");
+
   auto results = operation->getResults();
-  auto resultTy = operation->getOperand(0).getType().dyn_cast<ShapedType>();
+  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
 
   if (!resultTy)
     return rewriter.notifyMatchFailure(operation,
@@ -531,9 +535,6 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
   unsigned rank = resultTy.getRank();
 
-  assert(operation->getNumResults() == 1 &&
-         "All TOSA elementwise ops should only return a single result.");
-
   // Construct the indexing maps needed for linalg.generic ops.
   SmallVector<Type> bodyArgTypes;
 
@@ -565,11 +566,18 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   // Input indexing maps may be broadcasted.
   for (Value operand : operation->getOperands()) {
     ShapedType type = operand.getType().cast<ShapedType>();
+
+    if (type.getShape() == resultTy.getShape()) {
+      operands.push_back(operand);
+      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
+      continue;
+    }
+
     SmallVector<int64_t, 5> newShape;
     SmallVector<AffineExpr, 4> affineExprs;
     newShape.reserve(type.getRank());
     for (auto it : llvm::enumerate(type.getShape())) {
-      if (it.value() != 1) {
+      if (it.value() == resultTy.getDimSize(it.index())) {
         newShape.push_back(it.value());
         affineExprs.push_back(
             mlir::getAffineDimExpr(it.index(), rewriter.getContext()));

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index dbd4f9093ff40..4916c703ef8c5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -73,6 +73,24 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL: @test_broadcast_swapped_args
+func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+  // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK:   linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>


        


More information about the Mlir-commits mailing list