[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 14:13:03 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127923

From bb978ee15c57659815f1b7dc8b83af0a2b5e3e6c Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge@arm.com>
Date: Wed, 19 Feb 2025 14:13:43 -0800
Subject: [PATCH] [mlir][tosa] Add more verifiers for the following operators

For ConcatOp, this commit also enhances the verifier by
checking four additional conditions:
- The input list is not empty
- The axis value is within the valid range [0, rank) of the first input shape
- All inputs have the same rank
- All dimensions other than the concatenation axis have matching sizes
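
For example (a minimal sketch mirroring the new invalid.mlir tests; the
function name is illustrative), a rank mismatch between inputs is now
rejected at verification time:

  func.func @concat_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
    // Rejected: operands have ranks 3 and 2.
    %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
    return %0 : tensor<2x2x3xf32>
  }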

For MatMulOp:
- Checked that inputs a and b are shaped tensor types with matching
  element types (and matching storage widths for quantized types)
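
As a hypothetical illustration (shapes are made up, and the exact
tosa.matmul operand list may differ across dialect revisions), mixing
element types on a and b now fails verification:

  func.func @matmul_element_type_mismatch(%a: tensor<1x14x19xf32>, %b: tensor<1x19x28xi8>) -> tensor<1x14x28xf32> {
    // Rejected: expect same element type for inputs a and b, got 'f32' and 'i8'.
    %0 = tosa.matmul %a, %b : (tensor<1x14x19xf32>, tensor<1x19x28xi8>) -> tensor<1x14x28xf32>
    return %0 : tensor<1x14x28xf32>
  }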

For the following operators, added the verifySameElementTypes check:
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp
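
For instance (mirroring the updated reshape test in invalid.mlir), an op
that silently changes the element type is now rejected by its own
verifier instead of by return-type inference:

  func.func @reshape_element_type_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32> {
    %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
    // Rejected: expect input and output to have same element type, got 'f32' and 'i32'.
    %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
    return %0 : tensor<13x21x3x1xi32>
  }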

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <tai.ly@arm.com>
Co-authored-by: Luke Hutton <luke.hutton@arm.com>
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |   8 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 208 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/invalid.mlir          |  46 +++-
 3 files changed, 249 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..0e5df48fb9d15 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -310,6 +310,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   ];
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -344,6 +345,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   ];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1471,6 +1473,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
 
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1846,6 +1849,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -2102,6 +2106,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2135,6 +2141,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..cb198471818e7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -871,6 +871,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // Check that each input has the same element type as the output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  if (firstInputShape.hasRank()) {
+    // Check axis is in expected range
+    if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 < axis < "
+                         "rank(input1[0]), got ")
+             << axis;
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstInputShape.getRank();
+
+    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstInputShape.getDimSize(i);
+        if (i == axis || firstInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +985,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types
+  if (!aType) {
+    emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+    return failure();
+  }
+  if (!bType) {
+    emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+    return failure();
+  }
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      emitOpError(
+          "expect operands to be both quantized or both not quantized, got ")
+          << aElementType << " and " << bElementType;
+      return failure();
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      emitOpError("expect quantized operands to have same widths, got ")
+          << aQuantWidth << " and " << bQuantWidth;
+      return failure();
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+  if (aElementType != bElementType) {
+    emitOpError("expect same element type for inputs a and b, got ")
+        << aElementType << " and " << bElementType;
+    return failure();
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -968,6 +1084,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
+
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1041,6 +1171,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
@@ -1048,14 +1182,12 @@ LogicalResult tosa::SliceOp::verify() {
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1260,6 +1392,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1341,6 +1478,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
@@ -1528,6 +1670,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1628,6 +1775,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -1789,6 +1941,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2408,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2546,6 +2715,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2574,6 +2747,33 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    emitOpError("expect shaped tensor for input1, got ")
+        << getInput1().getType();
+    return failure();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    emitOpError("expect element type of bool for input1, got ")
+        << predicateElementType;
+    return failure();
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..d0d98af38abc4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -193,8 +193,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  // expected-error@+1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -202,14 +201,36 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 // -----
 
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
 
 // -----
 
+func.func @test_concat_zero_inputs() {
+  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 <= axis < rank(input1[0]), got -1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 <= axis < rank(input1[0]), got 3}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
   %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
@@ -236,6 +257,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
 
 // -----
 
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect all operands to have the same rank, but got 3 vs 2 on operands 0 and 1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
+  return %0 : tensor<2x2x3xf32>
+}
+
+// -----
+
 func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
@@ -430,8 +459,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
 
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  // expected-error@+1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
   %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
   return
 }
@@ -521,7 +549,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
 
 func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
-  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
   return
 }
 
@@ -634,7 +662,7 @@ func.func @test_slice_invalid_start() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error@+1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+  // expected-error@+1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
   return
 }
@@ -645,7 +673,7 @@ func.func @test_slice_invalid_size() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // expected-error@+1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+  // expected-error@+1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
