[Mlir-commits] [mlir] [MLIR][TOSA] add additional verification to TOSA (PR #108133)
Arteen Abrishami
llvmlistbot at llvm.org
Tue Sep 10 19:02:52 PDT 2024
https://github.com/arteen1000 created https://github.com/llvm/llvm-project/pull/108133
----------
Motivation:
----------
Spec conformance. Stricter verification also lets code that operates on TOSA ops rely on these invariants.
------------
Changes Made:
------------
Add full verification of the permutation tensor to tosa.TRANSPOSE. Previously, it did not verify that the permutation values were in the range [0, rank - 1].
Update tosa.TRANSPOSE perms data type to be strictly i32.
Verify input/output shapes for tosa.TRANSPOSE.
Add a verifier to tosa.CONST that accounts for quantized result element types.
Make the TOSA tensor types spec-conformant by disallowing ranked tensors with any dimension of size 0.
This is distinct from rank-0 tensors, which remain allowed. For example, tensor<3x0xi32> is now rejected. Consequently, a TOSA tensor always has a positive number of elements.
>From 925f1f3f77fe437ddb932f0406be672b933b43cf Mon Sep 17 00:00:00 2001
From: Arteen Abrishami <arteen.abrishami at arm.com>
Date: Thu, 5 Sep 2024 20:15:51 +0000
Subject: [PATCH] [MLIR][TOSA] add additional verification to TOSA
----------
Motivation:
----------
Spec conformance. Stricter verification also lets code that operates on
TOSA ops rely on these invariants.
------------
Changes Made:
------------
Add full verification of the permutation tensor to tosa.TRANSPOSE.
Previously, it did not verify that the permutation values were in the range [0, rank - 1].
Update tosa.TRANSPOSE perms data type to be strictly i32.
Verify input/output shapes for tosa.TRANSPOSE.
Add a verifier to tosa.CONST that accounts for quantized result element types.
Make the TOSA tensor types spec-conformant by disallowing ranked tensors
with any dimension of size 0.
This is distinct from rank-0 tensors, which remain allowed. For example,
tensor<3x0xi32> is now rejected. Consequently, a TOSA tensor always has a
positive number of elements.
Signed-off-by: Arteen Abrishami <arteen.abrishami at arm.com>
---
.../mlir/Dialect/Tosa/IR/CMakeLists.txt | 4 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 58 ++++----
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 61 +++++---
.../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 13 ++
.../TosaToLinalg/TosaToLinalgNamed.cpp | 21 +--
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 64 ++++++---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 135 ++++++++++--------
.../TosaToLinalg/tosa-to-linalg-named.mlir | 10 +-
.../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 2 +-
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 4 +-
mlir/test/Dialect/Tosa/invalid.mlir | 74 ++++++++--
mlir/test/Dialect/Tosa/ops.mlir | 16 +++
12 files changed, 301 insertions(+), 161 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 12b4fc402c390f..1ee105f0ceb98b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
-mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
+mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..63572f287b7dde 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
-
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
let arguments = (ins
Tosa_Tensor5D:$input,
- TensorRankOf<[Tosa_Weight], [5]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
let arguments = (ins
Tosa_Tensor2D:$input,
- 2DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
);
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$filter,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
}];
let arguments = (ins
- I1Tensor:$input1
+ Tosa_I1Tensor:$input1
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
}
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
- I1Tensor:$pred,
+ Tosa_I1Tensor:$pred,
Tosa_Tensor:$on_true,
Tosa_Tensor:$on_false
);
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let hasFolder = 1;
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Int32Or64Tensor:$perms
+ Tosa_Int32Tensor:$perms
);
let results = (
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
);
let extraClassDeclaration = [{
- LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+ LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];
let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- 2DTensorOf<[Tosa_Int32]>:$indices
+ TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
);
let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- 2DTensorOf<[Tosa_Int32]>:$indices,
+ TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
Tosa_Tensor3D:$input
);
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+ TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];
let arguments = (ins
- I1Tensor:$cond,
+ Tosa_I1Tensor:$cond,
Variadic<Tosa_Tensor>:$inputs
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..c3a0128e95a84b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;
+//===----------------------------------------------------------------------===//
+// TOSA Tensor Conformance
+//===----------------------------------------------------------------------===//
+// True for a ranked tensor whose shape contains no dimension equal to 0.
+def HasNo0Dimensions : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+// Tensor of allowedTypes that is either unranked or ranked with no 0-size dim.
+class TosaTensorOf<
+ list<Type> allowedTypes, string summary = "tosa-conformant tensor">
+ : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
+// Ranked tensor of allowedTypes with no 0-size dim, plus any extra `preds`.
+class TosaRankedTensorOf<
+ list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
+ : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
+// Unranked tensor of allowedTypes; there is no shape to check dimensions on.
+class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
+ : UnrankedTensorOf<allowedTypes, preds, summary>;
+// Conformant ranked tensor whose rank is one of `ranks` (summary e.g. "1D/2D").
+class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
+ : TosaRankedTensorOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>],
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
-def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_I1Tensor : TosaTensorOf<[I1]>;
+def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
+def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>;
-def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
+def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
-def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
// Must be ranked but no further constraints
-def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
AnyFloat.predicate]>, "tosa.dtype">;
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
- AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+ AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
//===----------------------------------------------------------------------===//
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
// Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
-def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
def Tosa_Tensor1Dto6D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
def Tosa_TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
AnyTypeOf<types>.predicate,
VectorOf<types>.predicate,
- TensorOf<types>.predicate]>,
+ TosaTensorOf<types>.predicate]>,
description>;
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ef40b348ab5499..90fea1f68beb58 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
}
+// Returns input permuted by perms (result[i] = input[perms[i]]). perms must
+// have input.size() entries forming a permutation of [0, input.size()).
+template <typename T>
+SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
+ ArrayRef<int32_t> perms) {
+ assert(perms.size() == input.size() && "perms/input size mismatch");
+ SmallVector<T> permuted;
+ permuted.reserve(input.size());
+ for (size_t i = 0, e = input.size(); i < e; ++i)
+ permuted.push_back(input[perms[i]]);
+ return permuted;
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..fe53b499674324 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm;
+ SmallVector<int32_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);
@@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
@@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
if (5 == inputTy.getRank()) {
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm;
+ SmallVector<int32_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);
@@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
@@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
- SmallVector<int64_t> constantPerms;
+ SmallVector<int32_t> constantPerms;
if (failed(op.getConstantPerms(constantPerms)))
return failure();
Location loc = op.getLoc();
- // The verifier should have made sure we have a valid permutation tensor.
- assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+ // The verifier should have made sure we have a valid TOSA permutation
+ // tensor. isPermutationVector doesn't actually check the TOSA perms we
+ // expect.
SmallVector<OpFoldResult> inputSizes =
tensor::getMixedSizes(rewriter, loc, op.getInput1());
auto permutedSizes =
- applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+ applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
auto permutedInit = rewriter.create<tensor::EmptyOp>(
loc, permutedSizes, op.getInput1().getType().getElementType());
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
- op, op.getInput1(), permutedInit, constantPerms);
+ op, op.getInput1(), permutedInit,
+ llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; })));
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index da9a93feac4d65..03876a7c64d07c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
- SmallVector<int64_t> transposePerms, innerTransposePerms;
+ SmallVector<int32_t> transposePerms, innerTransposePerms;
if (transposeOp.getConstantPerms(transposePerms).failed())
return rewriter.notifyMatchFailure(transposeOp,
"transpose perms must be constant");
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
// IntDivOp inputs must be integer type, no need to check for quantized type
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
return lhsAttr.resizeSplat(resultTy);
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@@ -681,8 +690,10 @@ struct APIntFoldGreaterEqual {
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@@ -693,8 +704,10 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@@ -706,8 +719,10 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
@@ -838,14 +853,16 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};
// reshape(const(x)) -> const(reshape-attr(x))
- if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+ if (auto operand =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
// Constants must have static shape.
if (!outputTy.hasStaticShape())
return {};
// Okay to duplicate splat constants.
if (operand.isSplat())
- return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+ return SplatElementsAttr::get(outputTy,
+ operand.getSplatValue<Attribute>());
// Don't duplicate other constants.
if (!getInput1().hasOneUse())
@@ -905,7 +922,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
- auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
+ auto operandAttr =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
if (operandAttr)
return operandAttr;
@@ -954,7 +972,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getOnTrue() == getOnFalse())
return getOnTrue();
- auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+ auto predicate =
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
if (!predicate)
return {};
@@ -975,7 +994,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
- if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+ if (auto input =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
@@ -986,11 +1006,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
return {};
// Transpose is not the identity transpose.
- SmallVector<int64_t> perms;
+ SmallVector<int32_t> perms;
if (getConstantPerms(perms).failed())
return {};
- if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+ if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
return getInput1();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d93db1b237f316..0d0241fea5152c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -204,22 +204,6 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
-static bool hasZeroDimension(ShapedType shapedType) {
- if (!shapedType.hasRank())
- return false;
-
- auto rank = shapedType.getRank();
-
- for (int i = 0; i < rank; i++) {
- if (shapedType.isDynamicDim(i))
- continue;
- if (shapedType.getDimSize(i) == 0)
- return true;
- }
-
- return false;
-}
-
template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
@@ -236,10 +220,6 @@ static LogicalResult verifyConvOp(T op) {
return failure();
}
- if (hasZeroDimension(inputType))
- return op.emitOpError() << "tensor has a dimension with size zero. Each "
- "dimension of a tensor must have size >= 1";
-
auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
@@ -262,6 +242,29 @@ static LogicalResult verifyConvOp(T op) {
"allowed for float type");
return failure();
}
+ return success();
+}
+
+LogicalResult tosa::ConstOp::verify() {
+ // The value attribute and the result must both be tensor-typed.
+ auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
+ auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
+
+ if (!attrType || !outputType) {
+ emitOpError("expected tensors for attr/result type");
+ return failure();
+ }
+ // A quantized result may hold an attribute of its quantized storage type.
+ if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
+ outputType.getElementType())) {
+ if (result.getStorageType() == attrType.getElementType())
+ return success();
+ }
+ // Otherwise the attr and result element types must be identical.
+ if (attrType.getElementType() != outputType.getElementType()) {
+ emitOpError("expected same attr/result element types");
+ return failure();
+ }
return success();
}
@@ -283,9 +286,6 @@ LogicalResult tosa::ArgMaxOp::verify() {
LogicalResult tosa::AvgPool2dOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());
- if (hasZeroDimension(inputType))
- return emitOpError() << "tensor has a dimension with size zero. Each "
- "dimension of a tensor must have size >= 1";
auto inputETy = inputType.getElementType();
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -341,9 +341,9 @@ LogicalResult tosa::ClampOp::verify() {
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
- // if input datatype is float, check that the two min/max_fp attributes share
- // the same type and that their type is either the same of the input's
- // datatype, or a float type whose bitwidth > input datatype bitwidth
+ // If input datatype is float, check that the two min/max_fp attributes
+ // share the same type and that their type is either the same of the input's
+ // datatype, or a float type whose bitwidth > input datatype bitwidth.
if (!inputETy.isInteger(dataTypeBitWidth)) {
if (((maxFpType != minFpType) ||
(maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
@@ -383,7 +383,8 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
}
}
-/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
+/// Handles tosa.transpose_conv2d which has outpad and output shape
+/// attributes.
static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
@@ -420,9 +421,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
}
}
-/// The tosa.matmul op is also intended to be generated where a fully_connected
-/// op must be constructed where the weight is not a constant. In this case,
-/// the fully_connected op must be expressed using matmul.
+/// The tosa.matmul op is also intended to be generated where a
+/// fully_connected op must be constructed where the weight is not a constant.
+/// In this case, the fully_connected op must be expressed using matmul.
/// TODO: Add link to the leglization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
@@ -457,9 +458,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
}
}
-/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
-/// but avg_pool operator has its own builder as it has additional parameters
-/// not part of the unary ops.
+/// Both the tosa.avg_pool2d and unary ops use the same
+/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
+/// has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
@@ -526,8 +527,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
for (int i = 0, e = operands.size(); i != e; ++i) {
auto shape = operands.getShape(i);
if (!shape.hasRank()) {
- // TODO(jennik): Update function to have better case handling for invalid
- // operands and for ranked tensors.
+ // TODO(jennik): Update function to have better case handling for
+ // invalid operands and for ranked tensors.
return failure();
}
outRank = std::max<int64_t>(outRank, shape.getRank());
@@ -776,8 +777,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
return success();
}
- // If the input rank is unknown we can info the output rank using the padding
- // shape's first dim.
+ // If the input rank is unknown we can info the output rank using the
+ // padding shape's first dim.
if (!inputShape.hasRank()) {
if (paddingShape.isDynamicDim(0)) {
inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -1000,10 +1001,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
- if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
- return emitOpError() << "tensor has a dimension with size zero. Each "
- "dimension of a tensor must have size >= 1";
-
if ((int64_t)getNewShape().size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";
@@ -1034,16 +1031,15 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
+/// If the `perms` operand is defined by a constant, overwrite `perms` with
+/// its values and return success. Otherwise return failure and leave `perms`
+/// unmodified.
+LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
// Perms must be constants.
DenseIntElementsAttr permsAttr;
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
return failure();
- // Transpose is not the identity transpose.
- perms = llvm::to_vector(
- llvm::map_range(permsAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ // The perms operand is now required to hold i32 elements, so narrowing the
+ // sign-extended value to int32_t is lossless; spell the cast out so the
+ // int64_t -> int32_t conversion is intentional rather than implicit.
+ perms.clear();
+ perms.reserve(permsAttr.getNumElements());
+ for (const APInt &v : permsAttr.getValues<APInt>())
+ perms.push_back(static_cast<int32_t>(v.getSExtValue()));
return success();
}
@@ -1067,8 +1063,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
return success();
}
- // This would imply the number of permutations does not match the rank of the
- // input which is illegal.
+ // This would imply the number of permutations does not match the rank of
+ // the input which is illegal.
if (permsShape.getDimSize(0) != inputShape.getRank()) {
return failure();
}
@@ -1154,19 +1150,38 @@ LogicalResult tosa::TransposeOp::verify() {
<< " (output rank) but got size "
<< permType.getDimSize(0);
- SmallVector<int64_t> constantPerms;
+ SmallVector<int32_t> constantPerms;
if (succeeded(getConstantPerms(constantPerms))) {
- // Assert that the permutation tensor has a rank, which means that the rank
- // has been verified above.
+ // Assert that the permutation tensor has a rank, which means that the
+ // rank has been verified above.
assert(permType.hasRank() &&
"Unexpectedly found permutation tensor without rank");
- if (!isPermutationVector(constantPerms))
+ if (!llvm::all_of(constantPerms,
+ [&constantPerms](int32_t s) {
+ return s >= 0 &&
+ static_cast<size_t>(s) < constantPerms.size();
+ }) ||
+ !isPermutationVector(llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation tensor";
- if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
- return s < inputType.getRank();
- })) {
- return emitOpError() << "permutation must be within input bounds";
+ // Verify that the types of the input and output tensors are properly
+ // permuted.
+ if (inputType.hasRank() && outputType.hasRank()) {
+ assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+ inputType.getRank() == outputType.getRank());
+
+ for (auto i = 0; i < outputType.getRank(); i++) {
+ if (inputType.isDynamicDim(constantPerms[i]) ||
+ outputType.isDynamicDim(i))
+ continue;
+
+ if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+ return emitOpError()
+ << "expected output tensor dim " << i << " to match "
+ << "input dim " << constantPerms[i] << " with value of "
+ << inputType.getDimSize(constantPerms[i]);
+ }
}
}
return success();
@@ -1175,7 +1190,7 @@ LogicalResult tosa::TransposeOp::verify() {
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- SmallVector<int64_t> transposePerms;
+ SmallVector<int32_t> transposePerms;
if (getConstantPerms(transposePerms).failed())
return failure();
@@ -1184,7 +1199,7 @@ LogicalResult TransposeOp::reifyResultShapes(
SmallVector<OpFoldResult> returnedDims(inputType.getRank());
for (auto dim : transposePerms) {
- int64_t dimInInput = transposePerms[dim];
+ int32_t dimInInput = transposePerms[dim];
if (inputType.isDynamicDim(dimInInput))
returnedDims[dim] =
builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
@@ -1378,8 +1393,8 @@ static LogicalResult verifyReduceOp(T op) {
<< ")";
return failure();
}
- // We can only verify the reduced dimension size to be 1 if this is not the
- // special case of output rank == 0.
+ // We can only verify the reduced dimension size to be 1 if this is not
+ // the special case of output rank == 0.
if (outputRank != 0) {
auto outputShape = outputType.getShape();
if (!outputType.isDynamicDim(reduceAxis) &&
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 39699ee315e6cb..0d55d1899c713e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -521,7 +521,7 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
- // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
@@ -542,7 +542,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
- // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index c2bbfd5130ebcd..73da2810abe044 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
// check that tosa verify kick in
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 8e19f87dbf4aa8..2902c4a62009e9 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -80,14 +80,14 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
- %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+ %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
- %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+ %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 418f7687b3cce8..414bcfe237d753 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1,6 +1,22 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
+func.func @test_const() -> tensor<1xf32> {
+ // expected-error at +1{{'tosa.const' op expected same attr/result element types}}
+ %0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_const_non_tensor_attr() {
+ // expected-error at +1 {{'tosa.const' op expected tensors for attr/result type}}
+ %0 = "tosa.const"() {value = dense<1.0> : vector<f32>} : () -> tensor<f32>
+ return
+}
+
+// -----
+
func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error at +1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
@@ -148,6 +164,42 @@ func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>)
// -----
+func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation tensor}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation tensor}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+ return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+ return %1 : tensor<3x4xi32>
+}
+
+// -----
+
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
@@ -269,7 +321,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// -----
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
- // expected-error at +1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
return
}
@@ -277,7 +329,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> ()
// -----
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
- // expected-error at +1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
return
}
@@ -341,7 +393,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// -----
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
- // expected-error at +1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
@@ -350,8 +402,8 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1:
// -----
func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
- // expected-error at +1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
return %0 : tensor<1x27x27x16xf32>
}
@@ -360,7 +412,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
// -----
func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
@@ -369,7 +421,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
// -----
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
- // expected-error at +1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ // expected-error at +1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
@@ -469,7 +521,7 @@ func.func @test_tile_io_rank_mismatch() {
// CHECK-LABEL: @test_invalid_constant_permutation
func.func @test_invalid_constant_permutation() {
- // expected-error at +3 {{permutation must be within input bounds}}
+ // expected-error at +3 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = tensor.empty() : tensor<3x4x5xi32>
%1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
%2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
@@ -480,7 +532,7 @@ func.func @test_invalid_constant_permutation() {
// CHECK-LABEL: test_rank_size_constant_permutation
func.func @test_rank_size_constant_permutation() {
- // expected-error at +4 {{permutation must be within input bounds}}
+ // expected-error at +4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[0, 2]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
@@ -492,7 +544,7 @@ func.func @test_rank_size_constant_permutation() {
// CHECK-LABEL: test_large_constant_permutation
func.func @test_large_constant_permutation() {
- // expected-error at +4 {{permutation must be within input bounds}}
+ // expected-error at +4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
@@ -504,7 +556,7 @@ func.func @test_large_constant_permutation() {
// CHECK-LABEL: test_table_rank0_table
func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
- // expected-error at +1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
+ // expected-error at +1 {{'tosa.table' op operand #1 must be 1-d tosa-conformant tensor, but got 'tensor<i16>'}}
%0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 525ee917ccd9fd..a1600fd33c54b4 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -573,6 +573,22 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
return %1 : tensor<3x13x21xf32>
}
+// -----
+// CHECK-LABEL: transpose_dynamic_dim
+func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
+ %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+ return %1 : tensor<3x13x?xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_half_dynamic_dim
+func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
+ %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+ return %1 : tensor<3x13x?xf32>
+}
+
// -----
// CHECK-LABEL: gather
func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
More information about the Mlir-commits
mailing list