[Mlir-commits] [mlir] [mlir][Tosa] Fix attr type of out_shape for `tosa.transpose_conv2d` (PR #108041)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 08:17:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-tosa
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This patch fixes the attribute type of `out_shape`, which should be an i64 dense array attribute with at least 4 elements (it was previously constrained to at most 4 elements).
- Fix the description of `DenseArrayMaxCt`, which said "with at least n elements" even though its predicate checks `size() <= n`
- Move `DenseArrayMaxCt` from TosaTypesBase.td to CommonAttrConstraints.td, and add a new `DenseArrayMinCt` constraint alongside it
- Change the type of `out_shape` to `Tosa_IntArrayAttrAtLeast4`, an i64 dense array attribute confined by `DenseArrayMinCt<4>`
Fixes #107804.
---
Full diff: https://github.com/llvm/llvm-project/pull/108041.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-3)
- (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+8)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..8ad741b3e65fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -352,7 +352,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttrUpto4:$out_shape,
+ Tosa_IntArrayAttrAtLeast4:$out_shape,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..99f430cefa2f1e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -151,9 +151,6 @@ def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
//===----------------------------------------------------------------------===//
-class DenseArrayMaxCt<int n> : AttrConstraint<
- CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
- "with at least " # n # " elements">;
def Tosa_Fp32ArrayAttr2 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<2>]>;
def Tosa_Fp32ArrayAttr3 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<3>]>;
@@ -171,6 +168,8 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
+def Tosa_IntArrayAttrAtLeast4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMinCt<4>]>;
+
def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
"arbitrary float attribute"> {
let storageType = [{ ::mlir::FloatAttr }];
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 6774a7c568315d..853fb318c76e71 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -789,6 +789,14 @@ class DenseArrayCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>,
"with exactly " # n # " elements">;
+class DenseArrayMaxCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
+ "with at most " # n # " elements">;
+
+class DenseArrayMinCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() >= " # n>,
+ "with at least " # n # " elements">;
+
class DenseArrayStrictlyPositive<DenseArrayAttrBase arrayType> : AttrConstraint<
CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), "
"[&](auto v) { return v > 0; })">,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 418f7687b3cce8..0c38206e69423f 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -526,3 +526,12 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
%0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
return
}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
+func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+ // expected-error at +1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with at least 4 elements}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/108041
More information about the Mlir-commits
mailing list