[Mlir-commits] [mlir] fd004a4 - [mlir] tosa.concat - Add InferTensorType interface

Maya Amrami llvmlistbot at llvm.org
Tue Mar 21 08:01:18 PDT 2023


Author: Maya Amrami
Date: 2023-03-21T17:01:08+02:00
New Revision: fd004a4986eb3ecc14f03a4ff4eef9bc06c78059

URL: https://github.com/llvm/llvm-project/commit/fd004a4986eb3ecc14f03a4ff4eef9bc06c78059
DIFF: https://github.com/llvm/llvm-project/commit/fd004a4986eb3ecc14f03a4ff4eef9bc06c78059.diff

LOG: [mlir] tosa.concat - Add InferTensorType interface

When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.
A few changes were required in inferReturnTypeComponents():
- Emit error when it fails.
  The verifier calls this method now, and it is preferable to
  indicate what caused the failure.
- Fix the inferred return shapes so they have a type too.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D146132

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index be5720caeb0de..7c8018ad64606 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1419,8 +1419,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
+                              InferTensorType,
     Pure]> {
   let summary = "Concatenates tensors along one dimension.";
 
@@ -1439,6 +1438,12 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   );
 
   let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d7bb6d0bddbf6..0a09cdd19e2d8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -422,6 +422,12 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -447,14 +453,17 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
       if (outputShape[i] == ShapedType::kDynamic)
         outputShape[i] = operandShape.getDimSize(i);
       if (outputShape[i] != operandShape.getDimSize(i))
-        return failure();
+        return emitOptionalError(location,
+                                 "Cannot concat tensors with different sizes"
+                                 " on the non-axis dimension ",
+                                 i);
     }
 
     hasRankedInput = true;
   }
-
+  Type inputType = operands.getType()[0].cast<TensorType>().getElementType();
   if (!hasRankedInput) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
@@ -475,7 +484,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
 
   outputShape[axis] = concatDimSize;
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c81b19639cd64..9f9c6ca6ce641 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -36,4 +36,10 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
   return %0 : tensor<1x27x27x16xi8>
 }
 
+// -----
 
+func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 94eea3b36eae2..505350786d08d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -491,16 +491,6 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
 
 // -----
 
-// CHECK-LABEL: @test_concat_failure
-func.func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
-  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @test_padding_no_const
 func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
   // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>


        


More information about the Mlir-commits mailing list