[Mlir-commits] [mlir] fd004a4 - [mlir] tosa.concat - Add InferTensorType interface
Maya Amrami
llvmlistbot at llvm.org
Tue Mar 21 08:01:18 PDT 2023
Author: Maya Amrami
Date: 2023-03-21T17:01:08+02:00
New Revision: fd004a4986eb3ecc14f03a4ff4eef9bc06c78059
URL: https://github.com/llvm/llvm-project/commit/fd004a4986eb3ecc14f03a4ff4eef9bc06c78059
DIFF: https://github.com/llvm/llvm-project/commit/fd004a4986eb3ecc14f03a4ff4eef9bc06c78059.diff
LOG: [mlir] tosa.concat - Add InferTensorType interface
With this interface, inferReturnTypeComponents() is invoked both when the op
is created and when it is verified.
A few changes were required in inferReturnTypeComponents():
- Emit error when it fails.
The verifier calls this method now, and it is preferable to
indicate what caused the failure.
- Fix the inferred return shapes so that they also carry the element type
  (taken from the first operand), not just the shape.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D146132
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index be5720caeb0de..7c8018ad64606 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1419,8 +1419,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ InferTensorType,
Pure]> {
let summary = "Concatenates tensors along one dimension.";
@@ -1439,6 +1438,12 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
);
let hasCanonicalizer = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d7bb6d0bddbf6..0a09cdd19e2d8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -422,6 +422,12 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
return success();
}
+bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != r.size() || l.size() != 1)
+ return false;
+ return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -447,14 +453,17 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
if (outputShape[i] == ShapedType::kDynamic)
outputShape[i] = operandShape.getDimSize(i);
if (outputShape[i] != operandShape.getDimSize(i))
- return failure();
+ return emitOptionalError(location,
+ "Cannot concat tensors with different sizes"
+ " on the non-axis dimension ",
+ i);
}
hasRankedInput = true;
}
-
+ Type inputType = operands.getType()[0].cast<TensorType>().getElementType();
if (!hasRankedInput) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
@@ -475,7 +484,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
outputShape[axis] = concatDimSize;
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c81b19639cd64..9f9c6ca6ce641 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -36,4 +36,10 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
return %0 : tensor<1x27x27x16xi8>
}
+// -----
+func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
+ // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 94eea3b36eae2..505350786d08d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -491,16 +491,6 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
// -----
-// CHECK-LABEL: @test_concat_failure
-func.func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
- // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
- %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-
- return
-}
-
-// -----
-
// CHECK-LABEL: @test_padding_no_const
func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
// CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
More information about the Mlir-commits
mailing list