[Mlir-commits] [mlir] c8ac14d - [MLIR][Tosa] Pass encoding through `tosa-to-linalg`

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 15 10:45:29 PDT 2023


Author: rikhuijzer
Date: 2023-06-15T19:44:52+02:00
New Revision: c8ac14d754088b19c659ca0915229f1f28776831

URL: https://github.com/llvm/llvm-project/commit/c8ac14d754088b19c659ca0915229f1f28776831
DIFF: https://github.com/llvm/llvm-project/commit/c8ac14d754088b19c659ca0915229f1f28776831.diff

LOG: [MLIR][Tosa] Pass encoding through `tosa-to-linalg`

As pointed out by @Sinclair-Dee in
https://github.com/llvm/llvm-project/issues/62304, the `tosa-to-linalg`
conversion ignored the `encoding` attribute.

Also, this patch avoids an assertion-failure crash on unranked tensors.
Instead, the conversion now fails cleanly with a "failed to legalize" error.

Fixes #62304 and fixes #63165.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D152171

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0ca05882cca74..f30aa0b1521a8 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -526,12 +526,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   assert(operation->getNumResults() == 1 &&
          "All TOSA elementwise ops should only return a single result.");
 
-  auto results = operation->getResults();
-  auto resultTy = dyn_cast<ShapedType>(operation->getResult(0).getType());
+  auto result = operation->getResult(0);
+  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
   if (!resultTy)
-    return rewriter.notifyMatchFailure(operation,
-                                       "All results must be a shaped type");
+    return rewriter.notifyMatchFailure(
+        operation, "All results must be a ranked tensor type");
 
   unsigned rank = resultTy.getRank();
 
@@ -545,7 +545,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   SmallVector<Value> emptyTensors;
 
   SmallVector<Value> dynDims;
-  dynDims.resize(cast<ShapedType>(results.front().getType()).getRank());
+  dynDims.resize(rank);
 
   for (auto arg : operation->getOperands()) {
     auto operandTy = cast<ShapedType>(arg.getType());
@@ -557,12 +557,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
   SmallVector<Value> filteredDims = condenseValues(dynDims);
 
-  for (auto result : results) {
-    auto resultTy = cast<ShapedType>(result.getType());
-    emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
-    opResultTypes.push_back(result.getType());
-  }
+  emptyTensors.push_back(
+      rewriter.create<tensor::EmptyOp>(loc, resultTy, filteredDims));
+  opResultTypes.push_back(result.getType());
 
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 2310ba986b5dd..17eec59369186 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
 
 // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
 func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
@@ -6,3 +6,12 @@ func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.u
   %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
+
+// -----
+
+// CHECK-LABEL: @tensor_with_unknown_rank
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // expected-error at +1 {{failed to legalize operation 'tosa.abs'}}
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+  return %0 : tensor<*xi8>
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1f66c669bafb6..54c8e574125b7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -85,8 +85,20 @@ func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
   %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+// CHECK-LABEL: @test_encoding_passthrough
+func.func @test_encoding_passthrough(%arg0: tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> {
+  // CHECK: linalg.generic
+  // CHECK: sparse_tensor
+  %0 = "tosa.abs"(%arg0) : (tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector>
+  return %0 : tensor<2xi8, #SparseVector>
+}
+
+// -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>


        


More information about the Mlir-commits mailing list