[Mlir-commits] [mlir] e76d236 - [mlir] tosa.concat - fix isCompatibleReturnTypes()

Robert Suderman llvmlistbot at llvm.org
Mon Mar 27 15:09:16 PDT 2023


Author: Maya Amrami
Date: 2023-03-27T22:07:21Z
New Revision: e76d236c1cebfb7d31f722d242bc927b79dcdb49

URL: https://github.com/llvm/llvm-project/commit/e76d236c1cebfb7d31f722d242bc927b79dcdb49
DIFF: https://github.com/llvm/llvm-project/commit/e76d236c1cebfb7d31f722d242bc927b79dcdb49.diff

LOG: [mlir] tosa.concat - fix isCompatibleReturnTypes()

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D146901

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a09cdd19e2d8..13a43516f8a2c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -425,6 +425,8 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
 bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;
+  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
+    return false;
   return succeeded(verifyCompatibleShape(l[0], r[0]));
 }
 

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index bb7a3f5287c7f..5a120eed3a8f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -46,6 +46,14 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 
 // -----
 
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+  // expected-error @+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
  // expected-error @+1 {{'tosa.pad' op padding of pad is not constant}}
   %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>


        


More information about the Mlir-commits mailing list