[Mlir-commits] [mlir] e76d236 - [mlir] tosa.concat - fix isCompatibleReturnTypes()
Robert Suderman
llvmlistbot at llvm.org
Mon Mar 27 15:09:16 PDT 2023
Author: Maya Amrami
Date: 2023-03-27T22:07:21Z
New Revision: e76d236c1cebfb7d31f722d242bc927b79dcdb49
URL: https://github.com/llvm/llvm-project/commit/e76d236c1cebfb7d31f722d242bc927b79dcdb49
DIFF: https://github.com/llvm/llvm-project/commit/e76d236c1cebfb7d31f722d242bc927b79dcdb49.diff
LOG: [mlir] tosa.concat - fix isCompatibleReturnTypes()
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D146901
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a09cdd19e2d8..13a43516f8a2c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -425,6 +425,8 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
+/// Checks whether the two result-type ranges `l` and `r` are compatible for
+/// tosa.concat (the check is symmetric in `l` and `r`). Both ranges must hold
+/// exactly one type, the element types (via getElementTypeOrSelf) must be
+/// identical, and the shapes must pass verifyCompatibleShape. The element-type
+/// check is required because the shape check alone accepts, e.g., a declared
+/// tensor<?x?xi8> result for f32 operands whose inferred type is tensor<3x2xf32>.
bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != r.size() || l.size() != 1)
return false;
+ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
+ return false;
return succeeded(verifyCompatibleShape(l[0], r[0]));
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index bb7a3f5287c7f..5a120eed3a8f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -46,6 +46,14 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
+// The inferred (3x2) and declared (?x?) shapes are compatible, so this op is
+// rejected only because the declared i8 result element type differs from the
+// f32 operand element type.
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
// expected-error at +1 {{'tosa.pad' op padding of pad is not constant}}
%0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
More information about the Mlir-commits mailing list