[Mlir-commits] [mlir] 0600bb4 - [mlir][tosa] Elementwise operation dynamic shape support

Rob Suderman llvmlistbot at llvm.org
Thu Aug 26 11:20:16 PDT 2021


Author: Rob Suderman
Date: 2021-08-26T11:18:58-07:00
New Revision: 0600bb4d186799b55ac182484be23553407c3559

URL: https://github.com/llvm/llvm-project/commit/0600bb4d186799b55ac182484be23553407c3559
DIFF: https://github.com/llvm/llvm-project/commit/0600bb4d186799b55ac182484be23553407c3559.diff

LOG: [mlir][tosa] Elementwise operation dynamic shape support

Added dynamic shape support for elementwise operations. This assumes that
corresponding dynamic dimensions have equal sizes across all operands
(broadcasting a dynamic dimension whose runtime length is 1 is problematic).

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108730

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7e8b0e078ebc..beab93bc8bd0b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -615,16 +615,27 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
   SmallVector<Type> opResultTypes;
   SmallVector<Value> initTensors;
+
+  SmallVector<Value> dynDims;
+  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());
+
+  for (auto arg : operation->getOperands()) {
+    auto operandTy = arg.getType().cast<ShapedType>();
+    for (int i = 0; i < operandTy.getRank(); i++) {
+      if (operandTy.isDynamicDim(i) && !dynDims[i])
+        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
+    }
+  }
+
+  SmallVector<Value> filteredDims;
+  for (auto dim : dynDims)
+    if (dim)
+      filteredDims.push_back(dim);
+
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
-    if (!resultTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          operation,
-          "tosa to linalg conversion expects statically shaped tensors");
-
     initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
-        loc, ArrayRef<Value>({}), resultTy.getShape(),
-        resultTy.getElementType()));
+        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
     opResultTypes.push_back(result.getType());
   }
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c9e2f65907a43..8a3cf62405386 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -55,6 +55,34 @@ func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]]]
+  // CHECK: linalg.generic
+  // CHECK: absf
+  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs_dyn
+func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DIM]]]
+  // CHECK: linalg.generic
+  // CHECK: absf
+  %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+// -----
+
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
 
@@ -111,14 +139,6 @@ func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> ten
 
 // -----
 
-func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.abs'}}
-  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @test_simple_f32
 func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic


        


More information about the Mlir-commits mailing list