[Mlir-commits] [mlir] 2dd396c - [mlir] tosa.reshape - Add InferTensorType interface

Aviad Cohen llvmlistbot at llvm.org
Fri Apr 21 22:53:13 PDT 2023


Author: Aviad Cohen
Date: 2023-04-22T08:53:07+03:00
New Revision: 2dd396c18bc035f8f87fb7ca2c33b8f00c287759

URL: https://github.com/llvm/llvm-project/commit/2dd396c18bc035f8f87fb7ca2c33b8f00c287759
DIFF: https://github.com/llvm/llvm-project/commit/2dd396c18bc035f8f87fb7ca2c33b8f00c287759.diff

LOG: [mlir] tosa.reshape - Add InferTensorType interface

When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D148498

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 287e62465251d..e36ab18777d14 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1441,8 +1441,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-                              InferTensorType,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Concatenates tensors along one dimension.";
 
   let description = [{
@@ -1503,9 +1502,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
 // Operator: reshape
 //===----------------------------------------------------------------------===//
 def Tosa_ReshapeOp: Tosa_Op<"reshape", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+  InferTensorType, Pure]> {
   let summary = "Reshape operator";
 
   let description = [{
@@ -1526,6 +1523,12 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
   let results = (outs
     Tosa_RankedTensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b22bd6590f37a..2da687b81663c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -674,19 +674,27 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
+bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1) // exactly one type on each side
+    return false;
+  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
+} // Shapes are not compared here; only the element types must match.
+
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ReshapeOpAdaptor adaptor(operands, attributes);
   ShapeAdaptor inputShape = operands.getShape(0);
+  Type inputType = getElementTypeOrSelf(operands.getType()[0]);
   llvm::SmallVector<int64_t> newShapeValue =
       convertToMlirShape(adaptor.getNewShape());
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
   if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(newShapeValue, inputType));
     return success();
   }
 
@@ -707,7 +715,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
       val = numElements / staticMul;
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(newShapeValue, inputType));
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c05a1c4577b7b..27661f4c57847 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,3 +128,11 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
   return
 }
+
+// -----
+
+func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error @+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
+  return
+}


        


More information about the Mlir-commits mailing list