[Mlir-commits] [mlir] 2dd396c - [mlir] tosa.reshape - Add InferTensorType interface
Aviad Cohen
llvmlistbot at llvm.org
Fri Apr 21 22:53:13 PDT 2023
Author: Aviad Cohen
Date: 2023-04-22T08:53:07+03:00
New Revision: 2dd396c18bc035f8f87fb7ca2c33b8f00c287759
URL: https://github.com/llvm/llvm-project/commit/2dd396c18bc035f8f87fb7ca2c33b8f00c287759
DIFF: https://github.com/llvm/llvm-project/commit/2dd396c18bc035f8f87fb7ca2c33b8f00c287759.diff
LOG: [mlir] tosa.reshape - Add InferTensorType interface
When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D148498
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 287e62465251d..e36ab18777d14 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1441,8 +1441,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
- InferTensorType,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Concatenates tensors along one dimension.";
let description = [{
@@ -1503,9 +1502,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
// Operator: reshape
//===----------------------------------------------------------------------===//
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reshape operator";
let description = [{
@@ -1526,6 +1523,12 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
let results = (outs
Tosa_RankedTensor:$output
);
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b22bd6590f37a..2da687b81663c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -674,19 +674,27 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
return success();
}
+bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1) // need exactly one type each
+    return false;
+  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); // no shape check
+}
+
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ReshapeOpAdaptor adaptor(operands, attributes);
ShapeAdaptor inputShape = operands.getShape(0);
+ Type inputType = getElementTypeOrSelf(operands.getType()[0]);
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());
// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
- inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(newShapeValue, inputType));
return success();
}
@@ -707,7 +715,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
val = numElements / staticMul;
}
- inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(newShapeValue, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c05a1c4577b7b..27661f4c57847 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,3 +128,11 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
%0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
}
+
+// -----
+
+func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
+  return
+}
More information about the Mlir-commits mailing list